From fe6e7581d83d7ce3a878a6f89241aed23ecb1da8 Mon Sep 17 00:00:00 2001 From: Emmanuel Dias Date: Sun, 8 Dec 2024 18:09:01 -0300 Subject: [PATCH 1/8] add performance test --- tests/performance/array_pr_auc.xml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 tests/performance/array_pr_auc.xml diff --git a/tests/performance/array_pr_auc.xml b/tests/performance/array_pr_auc.xml new file mode 100644 index 00000000000..cd5155da992 --- /dev/null +++ b/tests/performance/array_pr_auc.xml @@ -0,0 +1,3 @@ + + SELECT avg(ifNotFinite(arrayPrAUC(arrayMap(x -> rand(x) / 0x100000000, range(2 + rand() % 100)), arrayMap(x -> rand(x) % 2, range(2 + rand() % 100))), 0)) FROM numbers(100000) + From a2dad7be45bcfba242a3650b5fd56b1c42a20c44 Mon Sep 17 00:00:00 2001 From: Emmanuel Dias Date: Mon, 9 Dec 2024 10:14:17 -0300 Subject: [PATCH 2/8] add alias --- ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt | 2 ++ docs/en/sql-reference/functions/array-functions.md | 2 ++ src/Functions/array/arrayPrAUC.cpp | 1 + tests/queries/0_stateless/03272_array_pr_auc.reference | 1 + tests/queries/0_stateless/03272_array_pr_auc.sql | 3 +++ utils/check-style/aspell-ignore/en/aspell-dict.txt | 1 + 6 files changed, 10 insertions(+) diff --git a/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt b/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt index bb2b85976b2..eb51c578293 100644 --- a/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt +++ b/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt @@ -1200,6 +1200,8 @@ arrayPartialShuffle arrayPartialSort arrayPopBack arrayPopFront +arrayPRAUC +arrayPrAUC arrayProduct arrayPushBack arrayPushFront diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index 50c5cd33804..f316a8e96a0 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -2185,6 +2185,8 @@ Calculate AUC 
(Area Under the Curve) for the Precision Recall curve. arrayPrAUC(arr_scores, arr_labels) ``` +Alias: `arrayPRAUC` + **Arguments** - `arr_scores` — scores prediction model gives. diff --git a/src/Functions/array/arrayPrAUC.cpp b/src/Functions/array/arrayPrAUC.cpp index 57f0f447bbd..c60649a020d 100644 --- a/src/Functions/array/arrayPrAUC.cpp +++ b/src/Functions/array/arrayPrAUC.cpp @@ -241,6 +241,7 @@ public: REGISTER_FUNCTION(ArrayPrAUC) { factory.registerFunction(); + factory.registerAlias("arrayPRAUC", "arrayPrAUC"); } } diff --git a/tests/queries/0_stateless/03272_array_pr_auc.reference b/tests/queries/0_stateless/03272_array_pr_auc.reference index ee37e657593..35c6eea8405 100644 --- a/tests/queries/0_stateless/03272_array_pr_auc.reference +++ b/tests/queries/0_stateless/03272_array_pr_auc.reference @@ -13,6 +13,7 @@ 0.5 0.5 0.8333333333 +0.8333333333 0.8055555555 0.5 0.3666666666 diff --git a/tests/queries/0_stateless/03272_array_pr_auc.sql b/tests/queries/0_stateless/03272_array_pr_auc.sql index dfdcc5b5fe4..e1572d091ee 100644 --- a/tests/queries/0_stateless/03272_array_pr_auc.sql +++ b/tests/queries/0_stateless/03272_array_pr_auc.sql @@ -14,6 +14,9 @@ select floor(arrayPrAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1] select floor(arrayPrAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]), 10); select floor(arrayPrAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]), 10); +-- alias test +select floor(arrayPRAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); + -- output value correctness test select floor(arrayPrAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); select floor(arrayPrAUC([0.1, 0.4, 0.4, 0.35, 0.8], [0, 0, 1, 1, 1]), 10); diff --git a/utils/check-style/aspell-ignore/en/aspell-dict.txt b/utils/check-style/aspell-ignore/en/aspell-dict.txt index baaec790fd0..77a1799f8e8 100644 --- a/utils/check-style/aspell-ignore/en/aspell-dict.txt +++ b/utils/check-style/aspell-ignore/en/aspell-dict.txt @@ -1272,6 +1272,7 @@ 
arrayPartialShuffle arrayPartialSort arrayPopBack arrayPopFront +arrayPRAUC arrayPrAUC arrayProduct arrayPushBack From 78803f05cb4e420d8a678a932100905c473acfd5 Mon Sep 17 00:00:00 2001 From: Emmanuel Dias Date: Tue, 10 Dec 2024 11:35:58 -0300 Subject: [PATCH 3/8] improve docs --- docs/en/sql-reference/functions/array-functions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index f316a8e96a0..0a779087a11 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -2177,7 +2177,7 @@ Result: ## arrayPrAUC -Calculate AUC (Area Under the Curve) for the Precision Recall curve. +Calculate the area under the precision-recall (PR) curve. A precision-recall curve is created by plotting precision on the y-axis and recall on the x-axis across all thresholds. The resulting value ranges from 0 to 1, with a higher value indicating better model performance. PR AUC is particularly useful for imbalanced datasets, providing a clearer comparison of performance compared to ROC AUC on those cases. For more details on what it is and when to use it, refer to and . 
**Syntax** From faf5748a79b0b005af0a1f64b91bed85399ea1bc Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Fri, 13 Dec 2024 12:30:09 +0000 Subject: [PATCH 4/8] Some fixups --- .../aspell-ignore/en/aspell-dict.txt | 3 +- .../functions/array-functions.md | 19 +- src/Functions/array/arrayAUC.cpp | 190 ++++++++++---- src/Functions/array/arrayPrAUC.cpp | 247 ------------------ .../{array_pr_auc.xml => arrayAUCPr.xml} | 2 +- ...auc.reference => 01064_arrayAUC.reference} | 0 ...01064_array_auc.sql => 01064_arrayAUC.sql} | 0 ...rence => 01202_arrayAUC_special.reference} | 0 ...special.sql => 01202_arrayAUC_special.sql} | 0 ...c.reference => 03272_arrayAUCPr.reference} | 1 - .../queries/0_stateless/03272_arrayAUCPr.sql | 48 ++++ .../0_stateless/03272_array_pr_auc.sql | 52 ---- .../aspell-ignore/en/aspell-dict.txt | 19 +- 13 files changed, 213 insertions(+), 368 deletions(-) delete mode 100644 src/Functions/array/arrayPrAUC.cpp rename tests/performance/{array_pr_auc.xml => arrayAUCPr.xml} (69%) rename tests/queries/0_stateless/{01064_array_auc.reference => 01064_arrayAUC.reference} (100%) rename tests/queries/0_stateless/{01064_array_auc.sql => 01064_arrayAUC.sql} (100%) rename tests/queries/0_stateless/{01202_array_auc_special.reference => 01202_arrayAUC_special.reference} (100%) rename tests/queries/0_stateless/{01202_array_auc_special.sql => 01202_arrayAUC_special.sql} (100%) rename tests/queries/0_stateless/{03272_array_pr_auc.reference => 03272_arrayAUCPr.reference} (95%) create mode 100644 tests/queries/0_stateless/03272_arrayAUCPr.sql delete mode 100644 tests/queries/0_stateless/03272_array_pr_auc.sql diff --git a/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt b/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt index eb51c578293..b26163362cf 100644 --- a/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt +++ b/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt @@ -1161,6 +1161,7 @@ argMin argmax argmin arrayAUC 
+arrayAUCPr arrayAll arrayAvg arrayCompact @@ -1200,8 +1201,6 @@ arrayPartialShuffle arrayPartialSort arrayPopBack arrayPopFront -arrayPRAUC -arrayPrAUC arrayProduct arrayPushBack arrayPushFront diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index 5667e8a39cc..caf6e877dfc 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -2144,7 +2144,8 @@ Result: ## arrayAUC -Calculate AUC (Area Under the Curve, which is a concept in machine learning, see more details: ). +Calculates the Area Under the Curve (AUC), which is a concept in machine learning. +For more details, please see [here](https://developers.google.com/machine-learning/glossary#pr-auc-area-under-the-pr-curve), [here](https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc#expandable-1) and [here](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). **Syntax** @@ -2178,18 +2179,20 @@ Result: └───────────────────────────────────────────────┘ ``` -## arrayPrAUC +## arrayAUCPr -Calculate the area under the precision-recall (PR) curve. A precision-recall curve is created by plotting precision on the y-axis and recall on the x-axis across all thresholds. The resulting value ranges from 0 to 1, with a higher value indicating better model performance. PR AUC is particularly useful for imbalanced datasets, providing a clearer comparison of performance compared to ROC AUC on those cases. For more details on what it is and when to use it, refer to and . +Calculate the area under the precision-recall (PR) curve. +A precision-recall curve is created by plotting precision on the y-axis and recall on the x-axis across all thresholds. +The resulting value ranges from 0 to 1, with a higher value indicating better model performance. 
+PR AUC is particularly useful for imbalanced datasets, providing a clearer picture of performance than ROC AUC in those cases. +For more details, please see [here](https://developers.google.com/machine-learning/glossary#pr-auc-area-under-the-pr-curve), [here](https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc#expandable-1) and [here](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). **Syntax** ``` sql -arrayPrAUC(arr_scores, arr_labels) +arrayAUCPr(arr_scores, arr_labels) ``` -Alias: `arrayPRAUC` - **Arguments** - `arr_scores` — scores prediction model gives. @@ -2204,13 +2207,13 @@ Returns PR-AUC value with type Float64. Query: ``` sql -select arrayPrAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); +select arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); ``` Result: ``` text -┌─arrayPrAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ +┌─arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ │ 0.8333333333333333 │ └─────────────────────────────────────────────────┘ ``` diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index 109c1e3c38d..7c37a1bd5ff 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -19,7 +19,7 @@ namespace ErrorCodes /** The function takes two arrays: scores and labels. - * Label can be one of two values: positive and negative. + * Label can be one of two values: positive (> 0) and negative (<= 0). * Score can be arbitrary number. * * These values are considered as the output of classifier. We have some true labels for objects. @@ -33,6 +33,8 @@ namespace ErrorCodes * or have false positive or false negative result. * Verying the threshold we can get different probabilities of false positive or false negatives or true positives, etc... 
* + * --------------------------------------------------------------------------------------------------------------------- + * * We can also calculate the True Positive Rate and the False Positive Rate: * * TPR (also called "sensitivity", "recall" or "probability of detection") @@ -73,13 +75,53 @@ namespace ErrorCodes * threshold = 0.8, TPR = 0, FPR = 0, TPR_raw = 0, FPR_raw = 0 * * The "curve" will be present by a line that moves one step either towards right or top on each threshold change. + * + * --------------------------------------------------------------------------------------------------------------------- + * + * We can also calculate the Precision and the Recall ("Pr"): + * + * Precision is the ratio `tp / (tp + fp)` where `tp` is the number of true positives and `fp` the number of false positives. + * It represents how often the classifier is correct when giving a positive result. + * Precision = P(label = positive | score > threshold) + * + * Recall is the ratio `tp / (tp + fn)` where `tp` is the number of true positives and `fn` the number of false negatives. + * It represents the probability of the classifier to give positive result if the object has positive label. + * Recall = P(score > threshold | label = positive) + * + * We can draw a curve of values of Precision and Recall with different threshold on [0..1] x [0..1] unit square. + * This curve is named "Precision Recall curve" (PR). + * + * For the curve we can calculate, literally, Area Under the Curve, that will be in the range of [0..1]. + * + * Let's look at the example: + * arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); + * + * 1. We have pairs: (-, 0.1), (-, 0.4), (+, 0.35), (+, 0.8) + * + * 2. Let's sort by score descending: (+, 0.8), (-, 0.4), (+, 0.35), (-, 0.1) + * + * 3. 
Let's draw the points: + * + * threshold = 0.8, TP = 0, FP = 0, FN = 2, Recall = 0.0, Precision = 1 + * threshold = 0.4, TP = 1, FP = 0, FN = 1, Recall = 0.5, Precision = 1 + * threshold = 0.35, TP = 1, FP = 1, FN = 1, Recall = 0.5, Precision = 0.5 + * threshold = 0.1, TP = 2, FP = 1, FN = 0, Recall = 1.0, Precision = 0.666 + * threshold = 0, TP = 2, FP = 2, FN = 0, Recall = 1.0, Precision = 0.5 + * + * This implementation uses the right Riemann sum (see https://en.wikipedia.org/wiki/Riemann_sum) to calculate the AUC. + * That is, each increment in area is calculated using `(R_n - R_{n-1}) * P_n`, + * where `R_n` is the Recall at the `n`-th point and `P_n` is the Precision at the `n`-th point. + * + * This implementation is not interpolated and is different from computing the AUC with the trapezoidal rule, + * which uses linear interpolation and can be too optimistic for the Precision Recall AUC metric. */ +template class FunctionArrayAUC : public IFunction { public: - static constexpr auto name = "arrayAUC"; - static FunctionPtr create(ContextPtr) { return std::make_shared(); } + static constexpr auto name = Pr ? 
"arrayAUCPr" : "arrayAUC"; + static FunctionPtr create(ContextPtr) { return std::make_shared>(); } private: static Float64 apply( @@ -87,7 +129,7 @@ private: const IColumn & labels, ColumnArray::Offset current_offset, ColumnArray::Offset next_offset, - bool scale) + [[maybe_unused]] bool scale) { struct ScoreLabel { @@ -96,54 +138,109 @@ private: }; size_t size = next_offset - current_offset; + + if (Pr && size == 0) + return 0.0; + PODArrayWithStackMemory sorted_labels(size); for (size_t i = 0; i < size; ++i) { - bool label = labels.getFloat64(current_offset + i) > 0; + sorted_labels[i].label = labels.getFloat64(current_offset + i) > 0; sorted_labels[i].score = scores.getFloat64(current_offset + i); - sorted_labels[i].label = label; } - /// Sorting scores in descending order to traverse the ROC curve from left to right + /// Sorting scores in descending order to traverse the ROC / Precision-Recall curve from left to right std::sort(sorted_labels.begin(), sorted_labels.end(), [](const auto & lhs, const auto & rhs) { return lhs.score > rhs.score; }); - /// We will first calculate non-normalized area. - - Float64 area = 0.0; - Float64 prev_score = sorted_labels[0].score; - size_t prev_fp = 0; - size_t prev_tp = 0; - size_t curr_fp = 0; - size_t curr_tp = 0; - for (size_t i = 0; i < size; ++i) + if constexpr (!Pr) { - /// Only increment the area when the score changes - if (sorted_labels[i].score != prev_score) + /// We will first calculate non-normalized area. 
+ Float64 area = 0.0; + Float64 prev_score = sorted_labels[0].score; + + size_t prev_fp = 0; + size_t prev_tp = 0; + size_t curr_fp = 0; + size_t curr_tp = 0; + + for (size_t i = 0; i < size; ++i) { - area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; /// Trapezoidal area under curve (might degenerate to zero or to a rectangle) - prev_fp = curr_fp; - prev_tp = curr_tp; - prev_score = sorted_labels[i].score; + /// Only increment the area when the score changes + if (sorted_labels[i].score != prev_score) + { + area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; /// Trapezoidal area under curve (might degenerate to zero or to a rectangle) + prev_fp = curr_fp; + prev_tp = curr_tp; + prev_score = sorted_labels[i].score; + } + + if (sorted_labels[i].label) + curr_tp += 1; /// The curve moves one step up. + else + curr_fp += 1; /// The curve moves one step right. } - if (sorted_labels[i].label) - curr_tp += 1; /// The curve moves one step up. - else - curr_fp += 1; /// The curve moves one step right. + area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; + + /// Then normalize it, if scale is true, dividing by the area to the area of rectangle. + + if (scale) + { + if (curr_tp == 0 || curr_tp == size) + return std::numeric_limits::quiet_NaN(); + return area / curr_tp / (size - curr_tp); + } + return area; } - - area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; - - /// Then normalize it, if scale is true, dividing by the area to the area of rectangle. 
- - if (scale) + else { - if (curr_tp == 0 || curr_tp == size) - return std::numeric_limits::quiet_NaN(); - return area / curr_tp / (size - curr_tp); + Float64 area = 0.0; + Float64 prev_score = sorted_labels[0].score; + + size_t prev_tp = 0; + size_t curr_tp = 0; /// True positives predictions (positive label and score > threshold) + size_t curr_p = 0; /// Total positive predictions (score > threshold) + Float64 curr_precision; + + for (size_t i = 0; i < size; ++i) + { + if (sorted_labels[i].score != prev_score) + { + /* Precision = TP / (TP + FP) + * Recall = TP / (TP + FN) + * + * Instead of calculating + * d_Area = Precision_n * (Recall_n - Recall_{n-1}), + * we can just calculate + * d_Area = Precision_n * (TP_n - TP_{n-1}) + * and later divide it by (TP + FN). + * + * This can be done because (TP + FN) is constant and equal to total positive labels. + */ + curr_precision = static_cast(curr_tp) / curr_p; /// curr_p should never be 0 because this if statement isn't executed on the first iteration and the + /// following iterations will have already counted (curr_p += 1) at least one positive prediction + area += curr_precision * (curr_tp - prev_tp); + prev_tp = curr_tp; + prev_score = sorted_labels[i].score; + } + + if (sorted_labels[i].label) + curr_tp += 1; + curr_p += 1; + } + + /// If there were no positive labels, Recall did not change and the area is 0 + if (curr_tp == 0) + return 0.0; + + curr_precision = curr_p > 0 ? 
static_cast(curr_tp) / curr_p : 1.0; + area += curr_precision * (curr_tp - prev_tp); + + /// Finally, we divide by (TP + FN) to obtain the Recall + /// At this point we've traversed the whole curve and curr_tp = total positive labels (TP + FN) + return area / curr_tp; } - return area; } static void vector( @@ -168,8 +265,8 @@ private: public: String getName() const override { return name; } - bool isVariadic() const override { return true; } - size_t getNumberOfArguments() const override { return 0; } + bool isVariadic() const override { return Pr ? false : true; } + size_t getNumberOfArguments() const override { return Pr ? 2 : 0; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } @@ -177,10 +274,11 @@ public: { size_t number_of_arguments = arguments.size(); - if (number_of_arguments < 2 || number_of_arguments > 3) + if ((!Pr && (number_of_arguments < 2 || number_of_arguments > 3)) + || (Pr && number_of_arguments != 2)) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3", - getName(), number_of_arguments); + "Number of arguments for function {} doesn't match: passed {}, should be {}", + getName(), number_of_arguments, Pr ? 
"2" : "2 or 3"); for (size_t i = 0; i < 2; ++i) { @@ -193,7 +291,7 @@ public: throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}", getName(), nested_type->getName()); } - if (number_of_arguments == 3) + if (!Pr && number_of_arguments == 3) { if (!isBool(arguments[2].type) || arguments[2].column.get() == nullptr || !isColumnConst(*arguments[2].column)) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument (scale) for function {} must be of type const Bool.", getName()); @@ -202,10 +300,7 @@ public: return std::make_shared(); } - DataTypePtr getReturnTypeForDefaultImplementationForDynamic() const override - { - return std::make_shared(); - } + DataTypePtr getReturnTypeForDefaultImplementationForDynamic() const override { return std::make_shared(); } ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { @@ -249,7 +344,8 @@ public: REGISTER_FUNCTION(ArrayAUC) { - factory.registerFunction(); + factory.registerFunction>(); + factory.registerFunction>(); } } diff --git a/src/Functions/array/arrayPrAUC.cpp b/src/Functions/array/arrayPrAUC.cpp deleted file mode 100644 index c60649a020d..00000000000 --- a/src/Functions/array/arrayPrAUC.cpp +++ /dev/null @@ -1,247 +0,0 @@ -#include -#include -#include -#include -#include -#include - - -namespace DB -{ - -namespace ErrorCodes -{ -extern const int ILLEGAL_COLUMN; -extern const int BAD_ARGUMENTS; -extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; -} - - -/** The function takes two arrays: scores and labels. - * Label can be one of two values: positive (> 0) and negative (<= 0). - * Score can be arbitrary number. - * - * These values are considered as the output of classifier. We have some true labels for objects. 
- * And classifier assigns some scores to objects that predict these labels in the following way: - * - we can define arbitrary threshold on score and predict that the label is positive if the score is greater than the threshold: - * - * f(object) = score - * predicted_label = score > threshold - * - * This way classifier may predict positive or negative value correctly - true positive (tp) or true negative (tn) - * or have false positive (fp) or false negative (fn) result. - * Varying the threshold we can get different probabilities of false positive or false negatives or true positives, etc... - * - * We can also calculate the Precision and the Recall: - * - * Precision is the ratio `tp / (tp + fp)` where `tp` is the number of true positives and `fp` the number of false positives. - * It represents how often the classifier is correct when giving a positive result. - * Precision = P(label = positive | score > threshold) - * - * Recall is the ratio `tp / (tp + fn)` where `tp` is the number of true positives and `fn` the number of false negatives. - * It represents the probability of the classifier to give positive result if the object has positive label. - * Recall = P(score > threshold | label = positive) - * - * We can draw a curve of values of Precision and Recall with different threshold on [0..1] x [0..1] unit square. - * This curve is named "Precision Recall curve" (PR). - * - * For the curve we can calculate, literally, Area Under the Curve, that will be in the range of [0..1]. - * - * Let's look at the example: - * arrayPrAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); - * - * 1. We have pairs: (-, 0.1), (-, 0.4), (+, 0.35), (+, 0.8) - * - * 2. Let's sort by score descending: (+, 0.8), (-, 0.4), (+, 0.35), (-, 0.1) - * - * 3. 
Let's draw the points: - * - * threshold = 0.8, TP = 0, FP = 0, FN = 2, Recall = 0.0, Precision = 1 - * threshold = 0.4, TP = 1, FP = 0, FN = 1, Recall = 0.5, Precision = 1 - * threshold = 0.35, TP = 1, FP = 1, FN = 1, Recall = 0.5, Precision = 0.5 - * threshold = 0.1, TP = 2, FP = 1, FN = 0, Recall = 1.0, Precision = 0.666 - * threshold = 0, TP = 2, FP = 2, FN = 0, Recall = 1.0, Precision = 0.5 - * - * This implementation uses the right Riemann sum (see https://en.wikipedia.org/wiki/Riemann_sum) to calculate the AUC. - * That is, each increment in area is calculated using `(R_n - R_{n-1}) * P_n`, - * where `R_n` is the Recall at the `n`-th point and `P_n` is the Precision at the `n`-th point. - * - * This implementation is not interpolated and is different from computing the AUC with the trapezoidal rule, - * which uses linear interpolation and can be too optimistic for the Precision Recall AUC metric. - */ - -class FunctionArrayPrAUC : public IFunction -{ -public: - static constexpr auto name = "arrayPrAUC"; - static FunctionPtr create(ContextPtr) { return std::make_shared(); } - -private: - static Float64 apply(const IColumn & scores, const IColumn & labels, ColumnArray::Offset current_offset, ColumnArray::Offset next_offset) - { - size_t size = next_offset - current_offset; - if (size == 0) - return 0.0; - - struct ScoreLabel - { - Float64 score; - bool label; - }; - - PODArrayWithStackMemory sorted_labels(size); - - for (size_t i = 0; i < size; ++i) - { - sorted_labels[i].label = labels.getFloat64(current_offset + i) > 0; - sorted_labels[i].score = scores.getFloat64(current_offset + i); - } - - /// Sorting scores in descending order to traverse the Precision Recall curve from left to right - std::sort(sorted_labels.begin(), sorted_labels.end(), [](const auto & lhs, const auto & rhs) { return lhs.score > rhs.score; }); - - size_t prev_tp = 0; - size_t curr_tp = 0; /// True positives predictions (positive label and score > threshold) - size_t curr_p = 0; /// 
Total positive predictions (score > threshold) - - Float64 prev_score = sorted_labels[0].score; - Float64 curr_precision; - - Float64 area = 0.0; - - for (size_t i = 0; i < size; ++i) - { - if (sorted_labels[i].score != prev_score) - { - /* Precision = TP / (TP + FP) - * Recall = TP / (TP + FN) - * - * Instead of calculating - * d_Area = Precision_n * (Recall_n - Recall_{n-1}), - * we can just calculate - * d_Area = Precision_n * (TP_n - TP_{n-1}) - * and later divide it by (TP + FN). - * - * This can be done because (TP + FN) is constant and equal to total positive labels. - */ - curr_precision = static_cast(curr_tp) / curr_p; /// curr_p should never be 0 because this if statement isn't executed on the first iteration and the - /// following iterations will have already counted (curr_p += 1) at least one positive prediction - area += curr_precision * (curr_tp - prev_tp); - prev_tp = curr_tp; - prev_score = sorted_labels[i].score; - } - - if (sorted_labels[i].label) - curr_tp += 1; - curr_p += 1; - } - - /// If there were no positive labels, Recall did not change and the area is 0 - if (curr_tp == 0) - return 0.0; - - curr_precision = curr_p > 0 ? 
static_cast(curr_tp) / curr_p : 1.0; - area += curr_precision * (curr_tp - prev_tp); - - /// Finally, we divide by (TP + FN) to obtain the Recall - /// At this point we've traversed the whole curve and curr_tp = total positive labels (TP + FN) - return area / curr_tp; - } - - static void vector( - const IColumn & scores, - const IColumn & labels, - const ColumnArray::Offsets & offsets, - PaddedPODArray & result, - size_t input_rows_count) - { - result.resize(input_rows_count); - - ColumnArray::Offset current_offset = 0; - for (size_t i = 0; i < input_rows_count; ++i) - { - auto next_offset = offsets[i]; - result[i] = apply(scores, labels, current_offset, next_offset); - current_offset = next_offset; - } - } - -public: - String getName() const override { return name; } - - bool isVariadic() const override { return true; } - size_t getNumberOfArguments() const override { return 2; } - - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } - - DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override - { - if (arguments.size() != 2) - throw Exception( - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Number of arguments for function {} doesn't match: passed {}, should be 2.", - getName(), - arguments.size()); - - for (size_t i = 0; i < 2; ++i) - { - const DataTypeArray * array_type = checkAndGetDataType(arguments[i].type.get()); - if (!array_type) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Both arguments for function {} must be of type Array", getName()); - - const auto & nested_type = array_type->getNestedType(); - - /// The first argument (scores) must be an array of numbers - if (i == 0 && !isNativeNumber(nested_type)) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "{} cannot process values of type {} in its first argument", getName(), nested_type->getName()); - - /// The second argument (labels) must be an array of numbers or enums - if (i == 1 && 
!isNativeNumber(nested_type) && !isEnum(nested_type)) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "{} cannot process values of type {} in its second argument", getName(), nested_type->getName()); - } - - return std::make_shared(); - } - - DataTypePtr getReturnTypeForDefaultImplementationForDynamic() const override { return std::make_shared(); } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override - { - ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst(); - ColumnPtr col2 = arguments[1].column->convertToFullColumnIfConst(); - - const ColumnArray * col_array1 = checkAndGetColumn(col1.get()); - if (!col_array1) - throw Exception( - ErrorCodes::ILLEGAL_COLUMN, - "Illegal column {} of first argument of function {}, should be an Array", - arguments[0].column->getName(), - getName()); - - const ColumnArray * col_array2 = checkAndGetColumn(col2.get()); - if (!col_array2) - throw Exception( - ErrorCodes::ILLEGAL_COLUMN, - "Illegal column {} of second argument of function {}, should be an Array", - arguments[1].column->getName(), - getName()); - - if (!col_array1->hasEqualOffsets(*col_array2)) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Array arguments for function {} must have equal sizes", getName()); - - auto col_res = ColumnVector::create(); - - vector(col_array1->getData(), col_array2->getData(), col_array1->getOffsets(), col_res->getData(), input_rows_count); - - return col_res; - } -}; - - -REGISTER_FUNCTION(ArrayPrAUC) -{ - factory.registerFunction(); - factory.registerAlias("arrayPRAUC", "arrayPrAUC"); -} - -} diff --git a/tests/performance/array_pr_auc.xml b/tests/performance/arrayAUCPr.xml similarity index 69% rename from tests/performance/array_pr_auc.xml rename to tests/performance/arrayAUCPr.xml index cd5155da992..94bef3c4d3d 100644 --- a/tests/performance/array_pr_auc.xml +++ b/tests/performance/arrayAUCPr.xml @@ -1,3 +1,3 @@ - SELECT 
avg(ifNotFinite(arrayPrAUC(arrayMap(x -> rand(x) / 0x100000000, range(2 + rand() % 100)), arrayMap(x -> rand(x) % 2, range(2 + rand() % 100))), 0)) FROM numbers(100000) + SELECT avg(ifNotFinite(arrayAUCPr(arrayMap(x -> rand(x) / 0x100000000, range(2 + rand() % 100)), arrayMap(x -> rand(x) % 2, range(2 + rand() % 100))), 0)) FROM numbers(100000) diff --git a/tests/queries/0_stateless/01064_array_auc.reference b/tests/queries/0_stateless/01064_arrayAUC.reference similarity index 100% rename from tests/queries/0_stateless/01064_array_auc.reference rename to tests/queries/0_stateless/01064_arrayAUC.reference diff --git a/tests/queries/0_stateless/01064_array_auc.sql b/tests/queries/0_stateless/01064_arrayAUC.sql similarity index 100% rename from tests/queries/0_stateless/01064_array_auc.sql rename to tests/queries/0_stateless/01064_arrayAUC.sql diff --git a/tests/queries/0_stateless/01202_array_auc_special.reference b/tests/queries/0_stateless/01202_arrayAUC_special.reference similarity index 100% rename from tests/queries/0_stateless/01202_array_auc_special.reference rename to tests/queries/0_stateless/01202_arrayAUC_special.reference diff --git a/tests/queries/0_stateless/01202_array_auc_special.sql b/tests/queries/0_stateless/01202_arrayAUC_special.sql similarity index 100% rename from tests/queries/0_stateless/01202_array_auc_special.sql rename to tests/queries/0_stateless/01202_arrayAUC_special.sql diff --git a/tests/queries/0_stateless/03272_array_pr_auc.reference b/tests/queries/0_stateless/03272_arrayAUCPr.reference similarity index 95% rename from tests/queries/0_stateless/03272_array_pr_auc.reference rename to tests/queries/0_stateless/03272_arrayAUCPr.reference index 35c6eea8405..ee37e657593 100644 --- a/tests/queries/0_stateless/03272_array_pr_auc.reference +++ b/tests/queries/0_stateless/03272_arrayAUCPr.reference @@ -13,7 +13,6 @@ 0.5 0.5 0.8333333333 -0.8333333333 0.8055555555 0.5 0.3666666666 diff --git a/tests/queries/0_stateless/03272_arrayAUCPr.sql 
b/tests/queries/0_stateless/03272_arrayAUCPr.sql new file mode 100644 index 00000000000..d7d5928d788 --- /dev/null +++ b/tests/queries/0_stateless/03272_arrayAUCPr.sql @@ -0,0 +1,48 @@ +-- type correctness tests +select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); +select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8))), 10); +select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8))), 10); +select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1)))), 10); +select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1)))), 10); +select floor(arrayAUCPr(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPr(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPr(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPr(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPr(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPr(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPr(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPr(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPr(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]), 10); + +-- output value correctness test +select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); +select floor(arrayAUCPr([0.1, 0.4, 0.4, 0.35, 0.8], [0, 0, 1, 1, 1]), 10); +select floor(arrayAUCPr([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]), 10); +select floor(arrayAUCPr([0.1, 0.35, 0.4, 0.4, 0.8], [1, 0, 1, 0, 0]), 10); +select floor(arrayAUCPr([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]), 10); +select floor(arrayAUCPr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 0, 1, 0, 0, 
0, 1, 0, 0, 1]), 10); +select floor(arrayAUCPr([0, 1, 1, 2, 2, 2, 3, 3, 3, 3], [1, 0, 1, 0, 0, 0, 1, 0, 0, 1]), 10); + +-- edge cases +SELECT floor(arrayAUCPr([1], [1]), 10); +SELECT floor(arrayAUCPr([1], [0]), 10); +SELECT floor(arrayAUCPr([0], [0]), 10); +SELECT floor(arrayAUCPr([0], [1]), 10); +SELECT floor(arrayAUCPr([1, 1], [1, 1]), 10); +SELECT floor(arrayAUCPr([1, 1], [0, 0]), 10); +SELECT floor(arrayAUCPr([1, 1], [0, 1]), 10); +SELECT floor(arrayAUCPr([0, 1], [0, 1]), 10); +SELECT floor(arrayAUCPr([1, 0], [0, 1]), 10); +SELECT floor(arrayAUCPr([0, 0, 1], [0, 1, 1]), 10); +SELECT floor(arrayAUCPr([0, 1, 1], [0, 1, 1]), 10); +SELECT floor(arrayAUCPr([0, 1, 1], [0, 0, 1]), 10); + +-- negative tests +select arrayAUCPr([], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select arrayAUCPr([0, 0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +select arrayAUCPr([0.1, 0.35], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS } +select arrayAUCPr([0.1, 0.4, 0.35, 0.8], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], [1, 1, 0, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +select arrayAUCPr(['a', 'b', 'c', 'd'], [1, 0, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select arrayAUCPr([0.1, 0.4, NULL, 0.8], [0, 0, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, NULL, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } diff --git a/tests/queries/0_stateless/03272_array_pr_auc.sql b/tests/queries/0_stateless/03272_array_pr_auc.sql deleted file mode 100644 index e1572d091ee..00000000000 --- a/tests/queries/0_stateless/03272_array_pr_auc.sql +++ /dev/null @@ -1,52 +0,0 @@ --- type correctness tests -select floor(arrayPrAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); -select floor(arrayPrAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8))), 10); -select floor(arrayPrAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8))), 10); 
-select floor(arrayPrAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1)))), 10); -select floor(arrayPrAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1)))), 10); -select floor(arrayPrAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]), 10); -select floor(arrayPrAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]), 10); -select floor(arrayPrAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]), 10); -select floor(arrayPrAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]), 10); -select floor(arrayPrAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]), 10); -select floor(arrayPrAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]), 10); -select floor(arrayPrAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]), 10); -select floor(arrayPrAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]), 10); -select floor(arrayPrAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]), 10); - --- alias test -select floor(arrayPRAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); - --- output value correctness test -select floor(arrayPrAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); -select floor(arrayPrAUC([0.1, 0.4, 0.4, 0.35, 0.8], [0, 0, 1, 1, 1]), 10); -select floor(arrayPrAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]), 10); -select floor(arrayPrAUC([0.1, 0.35, 0.4, 0.4, 0.8], [1, 0, 1, 0, 0]), 10); -select floor(arrayPrAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]), 10); -select floor(arrayPrAUC([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 0, 1, 0, 0, 0, 1, 0, 0, 1]), 10); -select floor(arrayPrAUC([0, 1, 1, 2, 2, 2, 3, 3, 3, 3], [1, 0, 1, 0, 0, 0, 1, 0, 0, 1]), 10); - --- edge cases -SELECT floor(arrayPrAUC([1], [1]), 10); -SELECT floor(arrayPrAUC([1], [0]), 10); -SELECT floor(arrayPrAUC([0], [0]), 10); -SELECT floor(arrayPrAUC([0], [1]), 10); -SELECT floor(arrayPrAUC([1, 1], [1, 1]), 10); -SELECT floor(arrayPrAUC([1, 
1], [0, 0]), 10); -SELECT floor(arrayPrAUC([1, 1], [0, 1]), 10); -SELECT floor(arrayPrAUC([0, 1], [0, 1]), 10); -SELECT floor(arrayPrAUC([1, 0], [0, 1]), 10); -SELECT floor(arrayPrAUC([0, 0, 1], [0, 1, 1]), 10); -SELECT floor(arrayPrAUC([0, 1, 1], [0, 1, 1]), 10); -SELECT floor(arrayPrAUC([0, 1, 1], [0, 0, 1]), 10); - --- negative tests -select arrayPrAUC([], []); -- { serverError BAD_ARGUMENTS } -select arrayPrAUC([0, 0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } -select arrayPrAUC([0.1, 0.35], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS } -select arrayPrAUC([0.1, 0.4, 0.35, 0.8], []); -- { serverError BAD_ARGUMENTS } -select arrayPrAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], [1, 1, 0, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } -select arrayPrAUC(cast(['false', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), [1, 0]); -- { serverError BAD_ARGUMENTS } -select arrayPrAUC(['a', 'b', 'c', 'd'], [1, 0, 1, 1]); -- { serverError BAD_ARGUMENTS } -select arrayPrAUC([0.1, 0.4, NULL, 0.8], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS } -select arrayPrAUC([0.1, 0.4, 0.35, 0.8], [0, NULL, 1, 1]); -- { serverError BAD_ARGUMENTS } diff --git a/utils/check-style/aspell-ignore/en/aspell-dict.txt b/utils/check-style/aspell-ignore/en/aspell-dict.txt index 77a1799f8e8..f34ab458d00 100644 --- a/utils/check-style/aspell-ignore/en/aspell-dict.txt +++ b/utils/check-style/aspell-ignore/en/aspell-dict.txt @@ -49,6 +49,7 @@ AutoML Autocompletion AvroConfluent AzureQueue +BFloat BIGINT BIGSERIAL BORO @@ -244,10 +245,8 @@ Deduplication DefaultTableEngine DelayedInserts DeliveryTag -Deltalake DeltaLake -deltalakeCluster -deltaLakeCluster +Deltalake Denormalize DestroyAggregatesThreads DestroyAggregatesThreadsActive @@ -380,15 +379,11 @@ Homebrew's HorizontalDivide Hostname HouseOps -hudi Hudi -hudiCluster HudiCluster HyperLogLog Hypot IANA -icebergCluster -IcebergCluster IDE IDEs IDNA @@ -409,6 +404,7 @@ IPTrie IProcessor IPv ITION +IcebergCluster Identifiant 
IdentifierQuotingRule IdentifierQuotingStyle @@ -1233,6 +1229,7 @@ argMin argmax argmin arrayAUC +arrayAUCPr arrayAll arrayAvg arrayCompact @@ -1272,8 +1269,6 @@ arrayPartialShuffle arrayPartialSort arrayPopBack arrayPopFront -arrayPRAUC -arrayPrAUC arrayProduct arrayPushBack arrayPushFront @@ -1618,9 +1613,11 @@ defaultValueOfArgumentType defaultValueOfTypeName delim deltaLake +deltaLakeCluster deltaSum deltaSumTimestamp deltalake +deltalakeCluster deltasum deltasumtimestamp demangle @@ -1940,10 +1937,13 @@ html http https hudi +hudi +hudiCluster hyperscan hypot hyvor iTerm +icebergCluster icosahedron icudata idempotency @@ -3168,4 +3168,3 @@ znode znodes zookeeperSessionUptime zstd -BFloat From ebbc7bb542b957c54f422d47c14cd22416b08a90 Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Fri, 13 Dec 2024 14:26:48 +0000 Subject: [PATCH 5/8] Fix FastTest --- .../02415_all_new_functions_must_be_documented.reference | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference index e6a17a49ec9..4365a994b7a 100644 --- a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference +++ b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference @@ -91,6 +91,7 @@ and appendTrailingCharIfAbsent array arrayAUC +arrayAUCPr arrayAll arrayAvg arrayCompact @@ -124,7 +125,6 @@ arrayMax arrayMin arrayPopBack arrayPopFront -arrayPrAUC arrayProduct arrayPushBack arrayPushFront From f1d701ca07fb4150d63d8d4c6f047b2b41aae079 Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Sun, 15 Dec 2024 14:44:24 +0000 Subject: [PATCH 6/8] Integrate #71910 --- .../functions/array-functions.md | 22 +++---- src/Functions/array/arrayAUC.cpp | 25 ++++---- .../{arrayAUCPr.xml => arrayAUCPR.xml} | 2 +- tests/performance/arrayROCAUC.xml | 4 ++ tests/performance/array_auc.xml | 4 -- 
tests/queries/0_stateless/01064_arrayAUC.sql | 56 ------------------ ....reference => 01064_arrayROCAUC.reference} | 1 + .../queries/0_stateless/01064_arrayROCAUC.sql | 59 +++++++++++++++++++ .../0_stateless/01202_arrayAUC_special.sql | 46 --------------- ...ce => 01202_arrayROCAUC_special.reference} | 0 .../0_stateless/01202_arrayROCAUC_special.sql | 46 +++++++++++++++ ...r.reference => 03272_arrayAUCPR.reference} | 0 .../queries/0_stateless/03272_arrayAUCPR.sql | 48 +++++++++++++++ .../queries/0_stateless/03272_arrayAUCPr.sql | 48 --------------- .../aspell-ignore/en/aspell-dict.txt | 3 +- 15 files changed, 186 insertions(+), 178 deletions(-) rename tests/performance/{arrayAUCPr.xml => arrayAUCPR.xml} (69%) create mode 100644 tests/performance/arrayROCAUC.xml delete mode 100644 tests/performance/array_auc.xml delete mode 100644 tests/queries/0_stateless/01064_arrayAUC.sql rename tests/queries/0_stateless/{01064_arrayAUC.reference => 01064_arrayROCAUC.reference} (98%) create mode 100644 tests/queries/0_stateless/01064_arrayROCAUC.sql delete mode 100644 tests/queries/0_stateless/01202_arrayAUC_special.sql rename tests/queries/0_stateless/{01202_arrayAUC_special.reference => 01202_arrayROCAUC_special.reference} (100%) create mode 100644 tests/queries/0_stateless/01202_arrayROCAUC_special.sql rename tests/queries/0_stateless/{03272_arrayAUCPr.reference => 03272_arrayAUCPR.reference} (100%) create mode 100644 tests/queries/0_stateless/03272_arrayAUCPR.sql delete mode 100644 tests/queries/0_stateless/03272_arrayAUCPr.sql diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index caf6e877dfc..a93352dbe9e 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -2142,7 +2142,7 @@ Result: ``` -## arrayAUC +## arrayROCAUC Calculates the Area Under the Curve (AUC), which is a concept in machine learning. 
For more details, please see [here](https://developers.google.com/machine-learning/glossary#pr-auc-area-under-the-pr-curve), [here](https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc#expandable-1) and [here](https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). @@ -2150,9 +2150,11 @@ For more details, please see [here](https://developers.google.com/machine-learni **Syntax** ``` sql -arrayAUC(arr_scores, arr_labels[, scale]) +arrayROCAUC(arr_scores, arr_labels[, scale]) ``` +Alias: `arrayAUC` + **Arguments** - `arr_scores` — scores prediction model gives. @@ -2168,18 +2170,18 @@ Returns AUC value with type Float64. Query: ``` sql -select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); ``` Result: ``` text -┌─arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ -│ 0.75 │ -└───────────────────────────────────────────────┘ +┌─arrayROCAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ +│ 0.75 │ +└──────────────────────────────────────────────────┘ ``` -## arrayAUCPr +## arrayAUCPR Calculate the area under the precision-recall (PR) curve. A precision-recall curve is created by plotting precision on the y-axis and recall on the x-axis across all thresholds. @@ -2190,7 +2192,7 @@ For more details, please see [here](https://developers.google.com/machine-learni **Syntax** ``` sql -arrayAUCPr(arr_scores, arr_labels) +arrayAUCPR(arr_scores, arr_labels) ``` **Arguments** @@ -2207,13 +2209,13 @@ Returns PR-AUC value with type Float64. 
Query: ``` sql -select arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); +select arrayAUCPR([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); ``` Result: ``` text -┌─arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ +┌─arrayAUCPR([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ │ 0.8333333333333333 │ └─────────────────────────────────────────────────┘ ``` diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index 7c37a1bd5ff..c076addd9c5 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -78,7 +78,7 @@ namespace ErrorCodes * * --------------------------------------------------------------------------------------------------------------------- * - * We can also calculate the Precision and the Recall ("Pr"): + * We can also calculate the Precision and the Recall ("PR"): * * Precision is the ratio `tp / (tp + fp)` where `tp` is the number of true positives and `fp` the number of false positives. * It represents how often the classifier is correct when giving a positive result. @@ -116,12 +116,12 @@ namespace ErrorCodes * which uses linear interpolation and can be too optimistic for the Precision Recall AUC metric. */ -template +template class FunctionArrayAUC : public IFunction { public: - static constexpr auto name = Pr ? "arrayAUCPr" : "arrayAUC"; - static FunctionPtr create(ContextPtr) { return std::make_shared>(); } + static constexpr auto name = PR ? 
"arrayAUCPR" : "arrayROCAUC"; + static FunctionPtr create(ContextPtr) { return std::make_shared>(); } private: static Float64 apply( @@ -139,7 +139,7 @@ private: size_t size = next_offset - current_offset; - if (Pr && size == 0) + if (PR && size == 0) return 0.0; PODArrayWithStackMemory sorted_labels(size); @@ -153,7 +153,7 @@ private: /// Sorting scores in descending order to traverse the ROC / Precision-Recall curve from left to right std::sort(sorted_labels.begin(), sorted_labels.end(), [](const auto & lhs, const auto & rhs) { return lhs.score > rhs.score; }); - if constexpr (!Pr) + if constexpr (!PR) { /// We will first calculate non-normalized area. Float64 area = 0.0; @@ -265,8 +265,8 @@ private: public: String getName() const override { return name; } - bool isVariadic() const override { return Pr ? false : true; } - size_t getNumberOfArguments() const override { return Pr ? 2 : 0; } + bool isVariadic() const override { return PR ? false : true; } + size_t getNumberOfArguments() const override { return PR ? 2 : 0; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } @@ -274,11 +274,11 @@ public: { size_t number_of_arguments = arguments.size(); - if ((!Pr && (number_of_arguments < 2 || number_of_arguments > 3)) - || (Pr && number_of_arguments != 2)) + if ((!PR && (number_of_arguments < 2 || number_of_arguments > 3)) + || (PR && number_of_arguments != 2)) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Number of arguments for function {} doesn't match: passed {}, should be {}", - getName(), number_of_arguments, Pr ? "2" : "2 or 3"); + getName(), number_of_arguments, PR ? 
"2" : "2 or 3"); for (size_t i = 0; i < 2; ++i) { @@ -291,7 +291,7 @@ public: throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}", getName(), nested_type->getName()); } - if (!Pr && number_of_arguments == 3) + if (!PR && number_of_arguments == 3) { if (!isBool(arguments[2].type) || arguments[2].column.get() == nullptr || !isColumnConst(*arguments[2].column)) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument (scale) for function {} must be of type const Bool.", getName()); @@ -346,6 +346,7 @@ REGISTER_FUNCTION(ArrayAUC) { factory.registerFunction>(); factory.registerFunction>(); + factory.registerAlias("arrayAUC", "arrayROCAUC"); /// Backward compatibility, also ROC AUC is often shorted to just AUC } } diff --git a/tests/performance/arrayAUCPr.xml b/tests/performance/arrayAUCPR.xml similarity index 69% rename from tests/performance/arrayAUCPr.xml rename to tests/performance/arrayAUCPR.xml index 94bef3c4d3d..aa8dec9c2c9 100644 --- a/tests/performance/arrayAUCPr.xml +++ b/tests/performance/arrayAUCPR.xml @@ -1,3 +1,3 @@ - SELECT avg(ifNotFinite(arrayAUCPr(arrayMap(x -> rand(x) / 0x100000000, range(2 + rand() % 100)), arrayMap(x -> rand(x) % 2, range(2 + rand() % 100))), 0)) FROM numbers(100000) + SELECT avg(ifNotFinite(arrayAUCPR(arrayMap(x -> rand(x) / 0x100000000, range(2 + rand() % 100)), arrayMap(x -> rand(x) % 2, range(2 + rand() % 100))), 0)) FROM numbers(100000) diff --git a/tests/performance/arrayROCAUC.xml b/tests/performance/arrayROCAUC.xml new file mode 100644 index 00000000000..212b101fe99 --- /dev/null +++ b/tests/performance/arrayROCAUC.xml @@ -0,0 +1,4 @@ + + + SELECT avg(ifNotFinite(arrayROCAUC(arrayMap(x -> rand(x) / 0x100000000, range(2 + rand() % 100)), arrayMap(x -> rand(x) % 2, range(2 + rand() % 100))), 0)) FROM numbers(100000) + diff --git a/tests/performance/array_auc.xml b/tests/performance/array_auc.xml deleted file mode 100644 index 59d321b3c62..00000000000 --- 
a/tests/performance/array_auc.xml +++ /dev/null @@ -1,4 +0,0 @@ - - - SELECT avg(ifNotFinite(arrayAUC(arrayMap(x -> rand(x) / 0x100000000, range(2 + rand() % 100)), arrayMap(x -> rand(x) % 2, range(2 + rand() % 100))), 0)) FROM numbers(100000) - diff --git a/tests/queries/0_stateless/01064_arrayAUC.sql b/tests/queries/0_stateless/01064_arrayAUC.sql deleted file mode 100644 index 5594b505223..00000000000 --- a/tests/queries/0_stateless/01064_arrayAUC.sql +++ /dev/null @@ -1,56 +0,0 @@ -select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8))); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8))); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1)))); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1)))); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]); -select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]); -select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]); -select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]); - -select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), true); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), true); -select arrayAUC([0.1, 0.4, 
0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), true); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), true); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], true); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], true); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], true); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], true); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], true); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], true); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], true); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], true); -select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], true); -select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], true); -select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], true); - -select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), false); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), false); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), false); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), false); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], false); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], false); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], false); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], false); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], false); -select 
arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], false); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], false); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], false); -select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], false); -select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], false); -select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], false); - --- negative tests -select arrayAUC([0, 0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } -select arrayAUC([0.1, 0.35], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS } -select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], materialize(true)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true, true); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } diff --git a/tests/queries/0_stateless/01064_arrayAUC.reference b/tests/queries/0_stateless/01064_arrayROCAUC.reference similarity index 98% rename from tests/queries/0_stateless/01064_arrayAUC.reference rename to tests/queries/0_stateless/01064_arrayROCAUC.reference index 3fd5483eb99..6ba29986ba7 100644 --- a/tests/queries/0_stateless/01064_arrayAUC.reference +++ b/tests/queries/0_stateless/01064_arrayROCAUC.reference @@ -46,3 +46,4 @@ 1 1 1 +3 diff --git a/tests/queries/0_stateless/01064_arrayROCAUC.sql b/tests/queries/0_stateless/01064_arrayROCAUC.sql new file mode 100644 index 00000000000..b5dac4c3741 --- /dev/null +++ b/tests/queries/0_stateless/01064_arrayROCAUC.sql @@ -0,0 +1,59 @@ +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8))); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8))); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1)))); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 
'true'] as Array(Enum8('false' = -1, 'true' = 1)))); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]); +select arrayROCAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]); +select arrayROCAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]); +select arrayROCAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]); + +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), true); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), true); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), true); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), true); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], true); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], true); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], true); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], true); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], true); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], true); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], true); +select arrayROCAUC(cast([-10, 
-40, -35, -80] as Array(Int64)), [0, 0, 1, 1], true); +select arrayROCAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], true); +select arrayROCAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], true); +select arrayROCAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], true); + +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), false); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), false); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), false); +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), false); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], false); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], false); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], false); +select arrayROCAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], false); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], false); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], false); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], false); +select arrayROCAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], false); +select arrayROCAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], false); +select arrayROCAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], false); +select arrayROCAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], false); + +-- negative tests +select arrayROCAUC([0, 0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +select arrayROCAUC([0.1, 0.35], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS } +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], materialize(true)); -- { serverError 
ILLEGAL_TYPE_OF_ARGUMENT } +select arrayROCAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true, true); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } + +-- alias +select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false); diff --git a/tests/queries/0_stateless/01202_arrayAUC_special.sql b/tests/queries/0_stateless/01202_arrayAUC_special.sql deleted file mode 100644 index a7276ec0620..00000000000 --- a/tests/queries/0_stateless/01202_arrayAUC_special.sql +++ /dev/null @@ -1,46 +0,0 @@ -SELECT arrayAUC([], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([1], [1]); -SELECT arrayAUC([1], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([], [1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([1, 2], [3]); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUC([1], [2, 3]); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUC([1, 1], [1, 1]); -SELECT arrayAUC([1, 1], [0, 0]); -SELECT arrayAUC([1, 1], [0, 1]); -SELECT arrayAUC([0, 1], [0, 1]); -SELECT arrayAUC([1, 0], [0, 1]); -SELECT arrayAUC([0, 0, 1], [0, 1, 1]); -SELECT arrayAUC([0, 1, 1], [0, 1, 1]); -SELECT arrayAUC([0, 1, 1], [0, 0, 1]); -SELECT arrayAUC([], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([1], [1], true); -SELECT arrayAUC([1], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([], [1], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([1, 2], [3], true); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUC([1], [2, 3], true); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUC([1, 1], [1, 1], true); -SELECT arrayAUC([1, 1], [0, 0], true); -SELECT arrayAUC([1, 1], [0, 1], true); -SELECT arrayAUC([0, 1], [0, 1], true); -SELECT arrayAUC([1, 0], [0, 1], true); -SELECT arrayAUC([0, 0, 1], [0, 1, 1], true); -SELECT arrayAUC([0, 1, 1], [0, 1, 1], true); -SELECT arrayAUC([0, 1, 1], [0, 0, 1], true); -SELECT arrayAUC([], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([1], [1], 
false); -SELECT arrayAUC([1], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([], [1], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([1, 2], [3], false); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUC([1], [2, 3], false); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUC([1, 1], [1, 1], false); -SELECT arrayAUC([1, 1], [0, 0], false); -SELECT arrayAUC([1, 1], [0, 1], false); -SELECT arrayAUC([0, 1], [0, 1], false); -SELECT arrayAUC([1, 0], [0, 1], false); -SELECT arrayAUC([0, 0, 1], [0, 1, 1], false); -SELECT arrayAUC([0, 1, 1], [0, 1, 1], false); -SELECT arrayAUC([0, 1, 1], [0, 0, 1], false); -SELECT arrayAUC([0, 1, 1], [0, 0, 1], false, true); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } -SELECT arrayAUC([0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } -SELECT arrayAUC([0, 1, 1], [0, 0, 1], 'false'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([0, 1, 1], [0, 0, 1], 4); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } \ No newline at end of file diff --git a/tests/queries/0_stateless/01202_arrayAUC_special.reference b/tests/queries/0_stateless/01202_arrayROCAUC_special.reference similarity index 100% rename from tests/queries/0_stateless/01202_arrayAUC_special.reference rename to tests/queries/0_stateless/01202_arrayROCAUC_special.reference diff --git a/tests/queries/0_stateless/01202_arrayROCAUC_special.sql b/tests/queries/0_stateless/01202_arrayROCAUC_special.sql new file mode 100644 index 00000000000..8921ecead5c --- /dev/null +++ b/tests/queries/0_stateless/01202_arrayROCAUC_special.sql @@ -0,0 +1,46 @@ +SELECT arrayROCAUC([], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([1], [1]); +SELECT arrayROCAUC([1], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([], [1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([1, 2], [3]); -- { serverError BAD_ARGUMENTS } +SELECT arrayROCAUC([1], [2, 3]); -- { serverError 
BAD_ARGUMENTS } +SELECT arrayROCAUC([1, 1], [1, 1]); +SELECT arrayROCAUC([1, 1], [0, 0]); +SELECT arrayROCAUC([1, 1], [0, 1]); +SELECT arrayROCAUC([0, 1], [0, 1]); +SELECT arrayROCAUC([1, 0], [0, 1]); +SELECT arrayROCAUC([0, 0, 1], [0, 1, 1]); +SELECT arrayROCAUC([0, 1, 1], [0, 1, 1]); +SELECT arrayROCAUC([0, 1, 1], [0, 0, 1]); +SELECT arrayROCAUC([], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([1], [1], true); +SELECT arrayROCAUC([1], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([], [1], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([1, 2], [3], true); -- { serverError BAD_ARGUMENTS } +SELECT arrayROCAUC([1], [2, 3], true); -- { serverError BAD_ARGUMENTS } +SELECT arrayROCAUC([1, 1], [1, 1], true); +SELECT arrayROCAUC([1, 1], [0, 0], true); +SELECT arrayROCAUC([1, 1], [0, 1], true); +SELECT arrayROCAUC([0, 1], [0, 1], true); +SELECT arrayROCAUC([1, 0], [0, 1], true); +SELECT arrayROCAUC([0, 0, 1], [0, 1, 1], true); +SELECT arrayROCAUC([0, 1, 1], [0, 1, 1], true); +SELECT arrayROCAUC([0, 1, 1], [0, 0, 1], true); +SELECT arrayROCAUC([], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([1], [1], false); +SELECT arrayROCAUC([1], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([], [1], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([1, 2], [3], false); -- { serverError BAD_ARGUMENTS } +SELECT arrayROCAUC([1], [2, 3], false); -- { serverError BAD_ARGUMENTS } +SELECT arrayROCAUC([1, 1], [1, 1], false); +SELECT arrayROCAUC([1, 1], [0, 0], false); +SELECT arrayROCAUC([1, 1], [0, 1], false); +SELECT arrayROCAUC([0, 1], [0, 1], false); +SELECT arrayROCAUC([1, 0], [0, 1], false); +SELECT arrayROCAUC([0, 0, 1], [0, 1, 1], false); +SELECT arrayROCAUC([0, 1, 1], [0, 1, 1], false); +SELECT arrayROCAUC([0, 1, 1], [0, 0, 1], false); +SELECT arrayROCAUC([0, 1, 1], [0, 0, 1], false, true); -- { serverError 
NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +SELECT arrayROCAUC([0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +SELECT arrayROCAUC([0, 1, 1], [0, 0, 1], 'false'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayROCAUC([0, 1, 1], [0, 0, 1], 4); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } diff --git a/tests/queries/0_stateless/03272_arrayAUCPr.reference b/tests/queries/0_stateless/03272_arrayAUCPR.reference similarity index 100% rename from tests/queries/0_stateless/03272_arrayAUCPr.reference rename to tests/queries/0_stateless/03272_arrayAUCPR.reference diff --git a/tests/queries/0_stateless/03272_arrayAUCPR.sql b/tests/queries/0_stateless/03272_arrayAUCPR.sql new file mode 100644 index 00000000000..e89210dbe7a --- /dev/null +++ b/tests/queries/0_stateless/03272_arrayAUCPR.sql @@ -0,0 +1,48 @@ +-- type correctness tests +select floor(arrayAUCPR([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); +select floor(arrayAUCPR([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8))), 10); +select floor(arrayAUCPR([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8))), 10); +select floor(arrayAUCPR([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1)))), 10); +select floor(arrayAUCPR([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1)))), 10); +select floor(arrayAUCPR(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPR(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPR(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPR(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPR(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPR(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPR(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]), 10); +select 
floor(arrayAUCPR(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]), 10); +select floor(arrayAUCPR(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]), 10); + +-- output value correctness test +select floor(arrayAUCPR([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); +select floor(arrayAUCPR([0.1, 0.4, 0.4, 0.35, 0.8], [0, 0, 1, 1, 1]), 10); +select floor(arrayAUCPR([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]), 10); +select floor(arrayAUCPR([0.1, 0.35, 0.4, 0.4, 0.8], [1, 0, 1, 0, 0]), 10); +select floor(arrayAUCPR([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]), 10); +select floor(arrayAUCPR([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 0, 1, 0, 0, 0, 1, 0, 0, 1]), 10); +select floor(arrayAUCPR([0, 1, 1, 2, 2, 2, 3, 3, 3, 3], [1, 0, 1, 0, 0, 0, 1, 0, 0, 1]), 10); + +-- edge cases +SELECT floor(arrayAUCPR([1], [1]), 10); +SELECT floor(arrayAUCPR([1], [0]), 10); +SELECT floor(arrayAUCPR([0], [0]), 10); +SELECT floor(arrayAUCPR([0], [1]), 10); +SELECT floor(arrayAUCPR([1, 1], [1, 1]), 10); +SELECT floor(arrayAUCPR([1, 1], [0, 0]), 10); +SELECT floor(arrayAUCPR([1, 1], [0, 1]), 10); +SELECT floor(arrayAUCPR([0, 1], [0, 1]), 10); +SELECT floor(arrayAUCPR([1, 0], [0, 1]), 10); +SELECT floor(arrayAUCPR([0, 0, 1], [0, 1, 1]), 10); +SELECT floor(arrayAUCPR([0, 1, 1], [0, 1, 1]), 10); +SELECT floor(arrayAUCPR([0, 1, 1], [0, 0, 1]), 10); + +-- negative tests +select arrayAUCPR([], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select arrayAUCPR([0, 0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +select arrayAUCPR([0.1, 0.35], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS } +select arrayAUCPR([0.1, 0.4, 0.35, 0.8], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select arrayAUCPR([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], [1, 1, 0, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +select arrayAUCPR(['a', 'b', 'c', 'd'], [1, 0, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select arrayAUCPR([0.1, 0.4, NULL, 0.8], [0, 0, 1, 1]); -- { serverError 
ILLEGAL_TYPE_OF_ARGUMENT } +select arrayAUCPR([0.1, 0.4, 0.35, 0.8], [0, NULL, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } diff --git a/tests/queries/0_stateless/03272_arrayAUCPr.sql b/tests/queries/0_stateless/03272_arrayAUCPr.sql deleted file mode 100644 index d7d5928d788..00000000000 --- a/tests/queries/0_stateless/03272_arrayAUCPr.sql +++ /dev/null @@ -1,48 +0,0 @@ --- type correctness tests -select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); -select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8))), 10); -select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8))), 10); -select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1)))), 10); -select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1)))), 10); -select floor(arrayAUCPr(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]), 10); -select floor(arrayAUCPr(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]), 10); -select floor(arrayAUCPr(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]), 10); -select floor(arrayAUCPr(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]), 10); -select floor(arrayAUCPr(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]), 10); -select floor(arrayAUCPr(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]), 10); -select floor(arrayAUCPr(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]), 10); -select floor(arrayAUCPr(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]), 10); -select floor(arrayAUCPr(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]), 10); - --- output value correctness test -select floor(arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]), 10); -select floor(arrayAUCPr([0.1, 0.4, 0.4, 0.35, 0.8], [0, 0, 1, 1, 1]), 10); -select floor(arrayAUCPr([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]), 10); -select 
floor(arrayAUCPr([0.1, 0.35, 0.4, 0.4, 0.8], [1, 0, 1, 0, 0]), 10); -select floor(arrayAUCPr([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]), 10); -select floor(arrayAUCPr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 0, 1, 0, 0, 0, 1, 0, 0, 1]), 10); -select floor(arrayAUCPr([0, 1, 1, 2, 2, 2, 3, 3, 3, 3], [1, 0, 1, 0, 0, 0, 1, 0, 0, 1]), 10); - --- edge cases -SELECT floor(arrayAUCPr([1], [1]), 10); -SELECT floor(arrayAUCPr([1], [0]), 10); -SELECT floor(arrayAUCPr([0], [0]), 10); -SELECT floor(arrayAUCPr([0], [1]), 10); -SELECT floor(arrayAUCPr([1, 1], [1, 1]), 10); -SELECT floor(arrayAUCPr([1, 1], [0, 0]), 10); -SELECT floor(arrayAUCPr([1, 1], [0, 1]), 10); -SELECT floor(arrayAUCPr([0, 1], [0, 1]), 10); -SELECT floor(arrayAUCPr([1, 0], [0, 1]), 10); -SELECT floor(arrayAUCPr([0, 0, 1], [0, 1, 1]), 10); -SELECT floor(arrayAUCPr([0, 1, 1], [0, 1, 1]), 10); -SELECT floor(arrayAUCPr([0, 1, 1], [0, 0, 1]), 10); - --- negative tests -select arrayAUCPr([], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -select arrayAUCPr([0, 0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } -select arrayAUCPr([0.1, 0.35], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS } -select arrayAUCPr([0.1, 0.4, 0.35, 0.8], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -select arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], [1, 1, 0, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } -select arrayAUCPr(['a', 'b', 'c', 'd'], [1, 0, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -select arrayAUCPr([0.1, 0.4, NULL, 0.8], [0, 0, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -select arrayAUCPr([0.1, 0.4, 0.35, 0.8], [0, NULL, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } diff --git a/utils/check-style/aspell-ignore/en/aspell-dict.txt b/utils/check-style/aspell-ignore/en/aspell-dict.txt index f34ab458d00..dd26977ece1 100644 --- a/utils/check-style/aspell-ignore/en/aspell-dict.txt +++ b/utils/check-style/aspell-ignore/en/aspell-dict.txt @@ -1229,7 +1229,7 @@ argMin argmax argmin 
arrayAUC -arrayAUCPr +arrayAUCPR arrayAll arrayAvg arrayCompact @@ -1272,6 +1272,7 @@ arrayPopFront arrayProduct arrayPushBack arrayPushFront +arrayROCAUC arrayRandomSample arrayReduce arrayReduceInRanges From eac10514e3447039451daddce9698cb089cf424e Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Sun, 15 Dec 2024 15:41:52 +0000 Subject: [PATCH 7/8] Fix fasttest --- .../02415_all_new_functions_must_be_documented.reference | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference index 4365a994b7a..c8fa24c3cc2 100644 --- a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference +++ b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference @@ -90,8 +90,7 @@ alphaTokens and appendTrailingCharIfAbsent array -arrayAUC -arrayAUCPr +arrayAUCPR arrayAll arrayAvg arrayCompact @@ -128,6 +127,7 @@ arrayPopFront arrayProduct arrayPushBack arrayPushFront +arrayROCAUC arrayRandomSample arrayReduce arrayReduceInRanges From b222382ecc73d5c50ace5fdf11b24da649e8f5e8 Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Sun, 15 Dec 2024 17:53:40 +0000 Subject: [PATCH 8/8] Add arrayPRAUC alias --- docs/en/sql-reference/functions/array-functions.md | 2 ++ src/Functions/array/arrayAUC.cpp | 3 ++- tests/queries/0_stateless/03272_arrayAUCPR.reference | 1 + tests/queries/0_stateless/03272_arrayAUCPR.sql | 3 +++ 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index a93352dbe9e..498f01db938 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -2195,6 +2195,8 @@ For more details, please see [here](https://developers.google.com/machine-learni arrayAUCPR(arr_scores, arr_labels) ``` +Alias: 
`arrayPRAUC` + **Arguments** - `arr_scores` — scores prediction model gives. diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index c076addd9c5..efcec71930e 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -94,7 +94,7 @@ namespace ErrorCodes * For the curve we can calculate, literally, Area Under the Curve, that will be in the range of [0..1]. * * Let's look at the example: - * arrayPrAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); + * arrayAUCPR([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); * * 1. We have pairs: (-, 0.1), (-, 0.4), (+, 0.35), (+, 0.8) * @@ -347,6 +347,7 @@ REGISTER_FUNCTION(ArrayAUC) factory.registerFunction>(); factory.registerFunction>(); factory.registerAlias("arrayAUC", "arrayROCAUC"); /// Backward compatibility, also ROC AUC is often shorted to just AUC + factory.registerAlias("arrayPRAUC", "arrayAUCPR"); } } diff --git a/tests/queries/0_stateless/03272_arrayAUCPR.reference b/tests/queries/0_stateless/03272_arrayAUCPR.reference index ee37e657593..6e6466b7ff7 100644 --- a/tests/queries/0_stateless/03272_arrayAUCPR.reference +++ b/tests/queries/0_stateless/03272_arrayAUCPR.reference @@ -31,3 +31,4 @@ 0.8333333333 1 0.5 +1 diff --git a/tests/queries/0_stateless/03272_arrayAUCPR.sql b/tests/queries/0_stateless/03272_arrayAUCPR.sql index e89210dbe7a..6d86886a674 100644 --- a/tests/queries/0_stateless/03272_arrayAUCPR.sql +++ b/tests/queries/0_stateless/03272_arrayAUCPR.sql @@ -46,3 +46,6 @@ select arrayAUCPR([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], [1, 1, 0, 1]); -- { serve select arrayAUCPR(['a', 'b', 'c', 'd'], [1, 0, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } select arrayAUCPR([0.1, 0.4, NULL, 0.8], [0, 0, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } select arrayAUCPR([0.1, 0.4, 0.35, 0.8], [0, NULL, 1, 1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } + +--alias +SELECT floor(arrayPRAUC([1], [1]), 10);