From 2218ebebbfaac4d5114a476eafe40f71bdd25455 Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 05:15:57 -0300 Subject: [PATCH 01/11] initial commit, tested function --- .../functions/array-functions.md | 35 +++ .../functions/array-functions.md | 37 +++ .../functions/array-functions.md | 35 +++ src/Functions/array/arrayAUCUnscaled.cpp | 212 ++++++++++++++++++ tests/fuzz/all.dict | 1 + tests/fuzz/dictionaries/functions.dict | 1 + tests/fuzz/dictionaries/old.dict | 1 + .../03237_array_auc_unscaled.reference | 25 +++ .../0_stateless/03237_array_auc_unscaled.sql | 30 +++ .../aspell-ignore/en/aspell-dict.txt | 1 + 10 files changed, 378 insertions(+) create mode 100644 src/Functions/array/arrayAUCUnscaled.cpp create mode 100644 tests/queries/0_stateless/03237_array_auc_unscaled.reference create mode 100644 tests/queries/0_stateless/03237_array_auc_unscaled.sql diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index ad971ae7554..89178cf8c5c 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -2116,6 +2116,41 @@ Result: └───────────────────────────────────────────────┘ ``` +## arrayAUC + +Calculate unscaled AUC (Area Under the Curve, which is a concept in machine learning, see more details: ), i.e. without dividing it by total true positives and total false positives. + +**Syntax** + +``` sql +arrayAUCUnscaled(arr_scores, arr_labels) +``` + +**Arguments** + +- `arr_scores` — scores prediction model gives. +- `arr_labels` — labels of samples, usually 1 for positive sample and 0 for negative sample. + +**Returned value** + +Returns unscaled AUC value with type Float64. + +**Example** + +Query: + +``` sql +select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); +``` + +Result: + +``` text +┌─arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ +│ 3.0 │ +└───────────────────────────────────────────────────────┘ +``` + ## arrayMap(func, arr1, ...) Returns an array obtained from the original arrays by application of `func(arr1[i], ..., arrN[i])` for each element. Arrays `arr1` ... `arrN` must have the same number of elements. diff --git a/docs/ru/sql-reference/functions/array-functions.md b/docs/ru/sql-reference/functions/array-functions.md index 825e3f06be2..7923e9af945 100644 --- a/docs/ru/sql-reference/functions/array-functions.md +++ b/docs/ru/sql-reference/functions/array-functions.md @@ -1654,6 +1654,43 @@ SELECT arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); └────────────────────────────────────────---──┘ ``` +## arrayAUCUnscaled {#arrayaucunscaled} + +Вычисляет площадь под кривой без нормализации. + +**Синтаксис** + +``` sql +arrayAUCUnscaled(arr_scores, arr_labels) +``` + +**Аргументы** + +- `arr_scores` — оценка, которую дает модель предсказания. +- `arr_labels` — ярлыки выборок, обычно 1 для содержательных выборок и 0 для бессодержательных выборок. + +**Возвращаемое значение** + +Значение площади под кривой без нормализации. + +Тип данных: `Float64`. + +**Пример** + +Запрос: + +``` sql +SELECT arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); +``` + +Результат: + +``` text +┌─arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ +│ 3.0 │ +└────────────────────────────────────────---────────────┘ +``` + ## arrayProduct {#arrayproduct} Возвращает произведение элементов [массива](../../sql-reference/data-types/array.md). diff --git a/docs/zh/sql-reference/functions/array-functions.md b/docs/zh/sql-reference/functions/array-functions.md index 69db34e4a36..5ff3e6a424c 100644 --- a/docs/zh/sql-reference/functions/array-functions.md +++ b/docs/zh/sql-reference/functions/array-functions.md @@ -1221,6 +1221,41 @@ select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); └───────────────────────────────────────────────┘ ``` +## arrayAUCUnscaled {#arrayaucunscaled} + +计算没有归一化的AUC (ROC曲线下的面积,这是机器学习中的一个概念,更多细节请查看:https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve)。 + +**语法** + +``` sql +arrayAUCUnscaled(arr_scores, arr_labels) +``` + +**参数** + +- `arr_scores` — 分数预测模型给出。 +- `arr_labels` — 样本的标签,通常为 1 表示正样本,0 表示负样本。 + +**返回值** + +返回 Float64 类型的非标准化 AUC 值。 + +**示例** + +查询语句: + +``` sql +select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); +``` + +结果: + +``` text +┌─arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ +│ 3.0 │ +└───────────────────────────────────────────────────────┘ +``` + ## arrayMap(func, arr1, ...) {#array-map} 将从 `func` 函数的原始应用中获得的数组返回给 `arr` 数组中的每个元素。 diff --git a/src/Functions/array/arrayAUCUnscaled.cpp b/src/Functions/array/arrayAUCUnscaled.cpp new file mode 100644 index 00000000000..2cf0d072218 --- /dev/null +++ b/src/Functions/array/arrayAUCUnscaled.cpp @@ -0,0 +1,212 @@ +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ILLEGAL_COLUMN; + extern const int BAD_ARGUMENTS; +} + + +/** The function takes two arrays: scores and labels. + * Label can be one of two values: positive and negative. + * Score can be arbitrary number. + * + * These values are considered as the output of classifier. We have some true labels for objects. + * And classifier assigns some scores to objects that predict these labels in the following way: + * - we can define arbitrary threshold on score and predict that the label is positive if the score is greater than the threshold: + * + * f(object) = score + * predicted_label = score > threshold + * + * This way classifier may predict positive or negative value correctly - true positive or true negative + * or have false positive or false negative result. + * Verying the threshold we can get different probabilities of false positive or false negatives or true positives, etc... + * + * We can also calculate the True Positive Rate and the False Positive Rate: + * + * TPR (also called "sensitivity", "recall" or "probability of detection") + * is the probability of classifier to give positive result if the object has positive label: + * TPR = P(score > threshold | label = positive) + * + * FPR is the probability of classifier to give positive result if the object has negative label: + * FPR = P(score > threshold | label = negative) + * + * We can draw a curve of values of FPR and TPR with different threshold on [0..1] x [0..1] unit square. + * This curve is named "ROC curve" (Receiver Operating Characteristic). + * + * For ROC we can calculate, literally, Area Under the Curve, that will be in the range of [0..1]. + * The higher the AUC the better the classifier. + * + * AUC also is as the probability that the score for positive label is greater than the score for negative label. + * + * https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc + * https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve + * + * To calculate AUC, we will draw points of (FPR, TPR) for different thresholds = score_i. + * FPR_raw = countIf(score > score_i, label = negative) = count negative labels above certain score + * TPR_raw = countIf(score > score_i, label = positive) = count positive labels above certain score + * + * Let's look at the example: + * arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); + * + * 1. We have pairs: (-, 0.1), (-, 0.4), (+, 0.35), (+, 0.8) + * + * 2. Let's sort by score: (-, 0.1), (+, 0.35), (-, 0.4), (+, 0.8) + * + * 3. Let's draw the points: + * + * threshold = 0, TPR_raw = 2, FPR_raw = 2 + * threshold = 0.1, TPR_raw = 2, FPR_raw = 1 + * threshold = 0.35, TPR_raw = 1, FPR_raw = 1 + * threshold = 0.4, TPR_raw = 1, FPR_raw = 0 + * threshold = 0.8, TPR_raw = 0, FPR_raw = 0 + * + * The "curve" will be present by a line that moves one step either towards right or top on each threshold change. + */ + +class FunctionArrayAUCUnscaled : public IFunction +{ +public: + static constexpr auto name = "arrayAUCUnscaled"; + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + +private: + static Float64 apply( + const IColumn & scores, + const IColumn & labels, + ColumnArray::Offset current_offset, + ColumnArray::Offset next_offset) + { + struct ScoreLabel + { + Float64 score; + bool label; + }; + + size_t size = next_offset - current_offset; + PODArrayWithStackMemory sorted_labels(size); + + for (size_t i = 0; i < size; ++i) + { + bool label = labels.getFloat64(current_offset + i) > 0; + sorted_labels[i].score = scores.getFloat64(current_offset + i); + sorted_labels[i].label = label; + } + + /// Sorting scores in descending order to traverse the ROC curve from left to right + std::sort(sorted_labels.begin(), sorted_labels.end(), [](const auto & lhs, const auto & rhs) { return lhs.score > rhs.score; }); + + Float64 area = 0.0; + Float64 prev_score = sorted_labels[0].score; + size_t prev_fp = 0, prev_tp = 0; + size_t curr_fp = 0, curr_tp = 0; + for (size_t i = 0; i < size; ++i) + { + // Only increment the area when the score changes + if (sorted_labels[i].score != prev_score) + { + area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; // Trapezoidal area under curve (might degenerate to zero or to a rectangle) + prev_fp = curr_fp; + prev_tp = curr_tp; + prev_score = sorted_labels[i].score; + } + + if (sorted_labels[i].label) + curr_tp += 1; /// The curve moves one step up. + else + curr_fp += 1; /// The curve moves one step right. + } + + area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; + + return area; + } + + static void vector( + const IColumn & scores, + const IColumn & labels, + const ColumnArray::Offsets & offsets, + PaddedPODArray & result, + size_t input_rows_count) + { + result.resize(input_rows_count); + + ColumnArray::Offset current_offset = 0; + for (size_t i = 0; i < input_rows_count; ++i) + { + auto next_offset = offsets[i]; + result[i] = apply(scores, labels, current_offset, next_offset); + current_offset = next_offset; + } + } + +public: + String getName() const override { return name; } + size_t getNumberOfArguments() const override { return 2; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + for (size_t i = 0; i < getNumberOfArguments(); ++i) + { + const DataTypeArray * array_type = checkAndGetDataType(arguments[i].get()); + if (!array_type) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "All arguments for function {} must be an array.", getName()); + + const auto & nested_type = array_type->getNestedType(); + if (!isNativeNumber(nested_type) && !isEnum(nested_type)) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}", + getName(), nested_type->getName()); + } + + return std::make_shared(); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst(); + ColumnPtr col2 = arguments[1].column->convertToFullColumnIfConst(); + + const ColumnArray * col_array1 = checkAndGetColumn(col1.get()); + if (!col_array1) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, + "Illegal column {} of first argument of function {}", arguments[0].column->getName(), getName()); + + const ColumnArray * col_array2 = checkAndGetColumn(col2.get()); + if (!col_array2) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, + "Illegal column {} of second argument of function {}", arguments[1].column->getName(), getName()); + + if (!col_array1->hasEqualOffsets(*col_array2)) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Array arguments for function {} must have equal sizes", getName()); + + auto col_res = ColumnVector::create(); + + vector( + col_array1->getData(), + col_array2->getData(), + col_array1->getOffsets(), + col_res->getData(), + input_rows_count); + + return col_res; + } +}; + + +REGISTER_FUNCTION(ArrayAUCUnscaled) +{ + factory.registerFunction(); +} + +} diff --git a/tests/fuzz/all.dict b/tests/fuzz/all.dict index 30af3746fca..6cb198a3e48 100644 --- a/tests/fuzz/all.dict +++ b/tests/fuzz/all.dict @@ -1216,6 +1216,7 @@ "argMinState" "array" "arrayAUC" +"arrayAUCUnscaled" "arrayAll" "arrayAvg" "arrayCompact" diff --git a/tests/fuzz/dictionaries/functions.dict b/tests/fuzz/dictionaries/functions.dict index e562595fb67..302aab97c2d 100644 --- a/tests/fuzz/dictionaries/functions.dict +++ b/tests/fuzz/dictionaries/functions.dict @@ -529,6 +529,7 @@ "argMinState" "array" "arrayAUC" +"arrayAUCUnscaled" "arrayAll" "arrayAvg" "arrayCompact" diff --git a/tests/fuzz/dictionaries/old.dict b/tests/fuzz/dictionaries/old.dict index 61914c3b283..6ecb5503ca4 100644 --- a/tests/fuzz/dictionaries/old.dict +++ b/tests/fuzz/dictionaries/old.dict @@ -19,6 +19,7 @@ "Array" "arrayAll" "arrayAUC" +"arrayAUCUnscaled" "arrayCompact" "arrayConcat" "arrayCount" diff --git a/tests/queries/0_stateless/03237_array_auc_unscaled.reference b/tests/queries/0_stateless/03237_array_auc_unscaled.reference new file mode 100644 index 00000000000..63204682fd4 --- /dev/null +++ b/tests/queries/0_stateless/03237_array_auc_unscaled.reference @@ -0,0 +1,25 @@ +3 +3 +3 +3 +3 +3 +3 +3 +3 +1 +1 +1 +1 +1 +1 +1 +0 +0 +0 +0.5 +1 +0 +1.5 +2 +1.5 diff --git a/tests/queries/0_stateless/03237_array_auc_unscaled.sql b/tests/queries/0_stateless/03237_array_auc_unscaled.sql new file mode 100644 index 00000000000..d4f07c42118 --- /dev/null +++ b/tests/queries/0_stateless/03237_array_auc_unscaled.sql @@ -0,0 +1,30 @@ +select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); +select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8))); +select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8))); +select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1)))); +select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1)))); +select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]); +select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]); +select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]); +select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]); +select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]); +select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]); +select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]); +select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]); +select arrayAUCUnscaled(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]); +select arrayAUCUnscaled([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]); +select arrayAUCUnscaled([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]); +SELECT arrayAUCUnscaled([], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUCUnscaled([1], [1]); +SELECT arrayAUCUnscaled([1], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUCUnscaled([], [1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUCUnscaled([1, 2], [3]); -- { serverError BAD_ARGUMENTS } +SELECT arrayAUCUnscaled([1], [2, 3]); -- { serverError BAD_ARGUMENTS } +SELECT arrayAUCUnscaled([1, 1], [1, 1]); +SELECT arrayAUCUnscaled([1, 1], [0, 0]); +SELECT arrayAUCUnscaled([1, 1], [0, 1]); +SELECT arrayAUCUnscaled([0, 1], [0, 1]); +SELECT arrayAUCUnscaled([1, 0], [0, 1]); +SELECT arrayAUCUnscaled([0, 0, 1], [0, 1, 1]); +SELECT arrayAUCUnscaled([0, 1, 1], [0, 1, 1]); +SELECT arrayAUCUnscaled([0, 1, 1], [0, 0, 1]); diff --git a/utils/check-style/aspell-ignore/en/aspell-dict.txt b/utils/check-style/aspell-ignore/en/aspell-dict.txt index 3467f21c812..f658b19e8a7 100644 --- a/utils/check-style/aspell-ignore/en/aspell-dict.txt +++ b/utils/check-style/aspell-ignore/en/aspell-dict.txt @@ -1153,6 +1153,7 @@ argMin argmax argmin arrayAUC +arrayAUCUnscaled arrayAll arrayAvg arrayCompact From 8f350a7ec931478b64d76501715c5cd4be5d4af3 Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 08:52:58 -0300 Subject: [PATCH 02/11] remove separate function --- .../functions/array-functions.md | 38 +--- .../functions/array-functions.md | 37 --- .../functions/array-functions.md | 35 --- src/Functions/array/arrayAUC.cpp | 39 +++- src/Functions/array/arrayAUCUnscaled.cpp | 212 ------------------ tests/fuzz/all.dict | 1 - tests/fuzz/dictionaries/functions.dict | 1 - tests/fuzz/dictionaries/old.dict | 1 - .../0_stateless/03237_array_auc_unscaled.sql | 60 ++--- .../03237_array_auc_unscaled.stdout-e | 25 +++ .../aspell-ignore/en/aspell-dict.txt | 1 - 11 files changed, 86 insertions(+), 364 deletions(-) delete mode 100644 src/Functions/array/arrayAUCUnscaled.cpp create mode 100644 tests/queries/0_stateless/03237_array_auc_unscaled.stdout-e diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index 89178cf8c5c..84396602d26 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -2088,13 +2088,14 @@ Calculate AUC (Area Under the Curve, which is a concept in machine learning, see **Syntax** ``` sql -arrayAUC(arr_scores, arr_labels) +arrayAUC(arr_scores, arr_labels[, scale]) ``` **Arguments** - `arr_scores` — scores prediction model gives. - `arr_labels` — labels of samples, usually 1 for positive sample and 0 for negative sample. +- `scale` - Optional. Wether to return the normalized area. Default value: true. [Bool] **Returned value** @@ -2116,41 +2117,6 @@ Result: └───────────────────────────────────────────────┘ ``` -## arrayAUC - -Calculate unscaled AUC (Area Under the Curve, which is a concept in machine learning, see more details: ), i.e. without dividing it by total true positives and total false positives. - -**Syntax** - -``` sql -arrayAUCUnscaled(arr_scores, arr_labels) -``` - -**Arguments** - -- `arr_scores` — scores prediction model gives. -- `arr_labels` — labels of samples, usually 1 for positive sample and 0 for negative sample. - -**Returned value** - -Returns unscaled AUC value with type Float64. - -**Example** - -Query: - -``` sql -select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); -``` - -Result: - -``` text -┌─arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ -│ 3.0 │ -└───────────────────────────────────────────────────────┘ -``` - ## arrayMap(func, arr1, ...) Returns an array obtained from the original arrays by application of `func(arr1[i], ..., arrN[i])` for each element. Arrays `arr1` ... `arrN` must have the same number of elements. diff --git a/docs/ru/sql-reference/functions/array-functions.md b/docs/ru/sql-reference/functions/array-functions.md index 7923e9af945..825e3f06be2 100644 --- a/docs/ru/sql-reference/functions/array-functions.md +++ b/docs/ru/sql-reference/functions/array-functions.md @@ -1654,43 +1654,6 @@ SELECT arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); └────────────────────────────────────────---──┘ ``` -## arrayAUCUnscaled {#arrayaucunscaled} - -Вычисляет площадь под кривой без нормализации. - -**Синтаксис** - -``` sql -arrayAUCUnscaled(arr_scores, arr_labels) -``` - -**Аргументы** - -- `arr_scores` — оценка, которую дает модель предсказания. -- `arr_labels` — ярлыки выборок, обычно 1 для содержательных выборок и 0 для бессодержательных выборок. - -**Возвращаемое значение** - -Значение площади под кривой без нормализации. - -Тип данных: `Float64`. - -**Пример** - -Запрос: - -``` sql -SELECT arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); -``` - -Результат: - -``` text -┌─arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ -│ 3.0 │ -└────────────────────────────────────────---────────────┘ -``` - ## arrayProduct {#arrayproduct} Возвращает произведение элементов [массива](../../sql-reference/data-types/array.md). diff --git a/docs/zh/sql-reference/functions/array-functions.md b/docs/zh/sql-reference/functions/array-functions.md index 5ff3e6a424c..69db34e4a36 100644 --- a/docs/zh/sql-reference/functions/array-functions.md +++ b/docs/zh/sql-reference/functions/array-functions.md @@ -1221,41 +1221,6 @@ select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); └───────────────────────────────────────────────┘ ``` -## arrayAUCUnscaled {#arrayaucunscaled} - -计算没有归一化的AUC (ROC曲线下的面积,这是机器学习中的一个概念,更多细节请查看:https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve)。 - -**语法** - -``` sql -arrayAUCUnscaled(arr_scores, arr_labels) -``` - -**参数** - -- `arr_scores` — 分数预测模型给出。 -- `arr_labels` — 样本的标签,通常为 1 表示正样本,0 表示负样本。 - -**返回值** - -返回 Float64 类型的非标准化 AUC 值。 - -**示例** - -查询语句: - -``` sql -select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); -``` - -结果: - -``` text -┌─arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐ -│ 3.0 │ -└───────────────────────────────────────────────────────┘ -``` - ## arrayMap(func, arr1, ...) {#array-map} 将从 `func` 函数的原始应用中获得的数组返回给 `arr` 数组中的每个元素。 diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index 7a61c9d368f..adaa52818d0 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -85,7 +85,8 @@ private: const IColumn & scores, const IColumn & labels, ColumnArray::Offset current_offset, - ColumnArray::Offset next_offset) + ColumnArray::Offset next_offset, + bool scale = true) { struct ScoreLabel { @@ -131,12 +132,15 @@ private: area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; - /// Then normalize it dividing by the area to the area of rectangle. + /// Then normalize it, if scale is true, dividing by the area to the area of rectangle. - if (curr_tp == 0 || curr_tp == size) + if (scale && (curr_tp == 0 || curr_tp == size)) return std::numeric_limits::quiet_NaN(); - return area / curr_tp / (size - curr_tp); + if (scale) + return area / curr_tp / (size - curr_tp); + else + return area; } static void vector( @@ -144,7 +148,8 @@ private: const IColumn & labels, const ColumnArray::Offsets & offsets, PaddedPODArray & result, - size_t input_rows_count) + size_t input_rows_count, + bool scale = true) { result.resize(input_rows_count); @@ -152,23 +157,23 @@ private: for (size_t i = 0; i < input_rows_count; ++i) { auto next_offset = offsets[i]; - result[i] = apply(scores, labels, current_offset, next_offset); + result[i] = apply(scores, labels, current_offset, next_offset, scale); current_offset = next_offset; } } public: String getName() const override { return name; } - size_t getNumberOfArguments() const override { return 2; } + size_t getNumberOfArguments() const override { return 3; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return false; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { - for (size_t i = 0; i < getNumberOfArguments(); ++i) + for (size_t i = 0; i < 2; ++i) { const DataTypeArray * array_type = checkAndGetDataType(arguments[i].get()); if (!array_type) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "All arguments for function {} must be an array.", getName()); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The two first arguments for function {} must be an array.", getName()); const auto & nested_type = array_type->getNestedType(); if (!isNativeNumber(nested_type) && !isEnum(nested_type)) @@ -176,6 +181,12 @@ public: getName(), nested_type->getName()); } + if (arguments.size() == 3) + { + if (!isBool(arguments[2])) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument must be a boolean (scale)"); + } + return std::make_shared(); } @@ -197,6 +208,13 @@ public: if (!col_array1->hasEqualOffsets(*col_array2)) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Array arguments for function {} must have equal sizes", getName()); + // Handle third argument for scale (if passed, otherwise default to true) + bool scale = true; + if (arguments.size() == 3) + { + scale = arguments[2].column->getBool(0); // Assumes it's a scalar boolean column + } + auto col_res = ColumnVector::create(); vector( @@ -204,7 +222,8 @@ public: col_array2->getData(), col_array1->getOffsets(), col_res->getData(), - input_rows_count); + input_rows_count, + scale); return col_res; } diff --git a/src/Functions/array/arrayAUCUnscaled.cpp b/src/Functions/array/arrayAUCUnscaled.cpp deleted file mode 100644 index 2cf0d072218..00000000000 --- a/src/Functions/array/arrayAUCUnscaled.cpp +++ /dev/null @@ -1,212 +0,0 @@ -#include -#include -#include -#include -#include -#include - - -namespace DB -{ - -namespace ErrorCodes -{ - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int ILLEGAL_COLUMN; - extern const int BAD_ARGUMENTS; -} - - -/** The function takes two arrays: scores and labels. - * Label can be one of two values: positive and negative. - * Score can be arbitrary number. - * - * These values are considered as the output of classifier. We have some true labels for objects. - * And classifier assigns some scores to objects that predict these labels in the following way: - * - we can define arbitrary threshold on score and predict that the label is positive if the score is greater than the threshold: - * - * f(object) = score - * predicted_label = score > threshold - * - * This way classifier may predict positive or negative value correctly - true positive or true negative - * or have false positive or false negative result. - * Verying the threshold we can get different probabilities of false positive or false negatives or true positives, etc... - * - * We can also calculate the True Positive Rate and the False Positive Rate: - * - * TPR (also called "sensitivity", "recall" or "probability of detection") - * is the probability of classifier to give positive result if the object has positive label: - * TPR = P(score > threshold | label = positive) - * - * FPR is the probability of classifier to give positive result if the object has negative label: - * FPR = P(score > threshold | label = negative) - * - * We can draw a curve of values of FPR and TPR with different threshold on [0..1] x [0..1] unit square. - * This curve is named "ROC curve" (Receiver Operating Characteristic). - * - * For ROC we can calculate, literally, Area Under the Curve, that will be in the range of [0..1]. - * The higher the AUC the better the classifier. - * - * AUC also is as the probability that the score for positive label is greater than the score for negative label. - * - * https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc - * https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve - * - * To calculate AUC, we will draw points of (FPR, TPR) for different thresholds = score_i. - * FPR_raw = countIf(score > score_i, label = negative) = count negative labels above certain score - * TPR_raw = countIf(score > score_i, label = positive) = count positive labels above certain score - * - * Let's look at the example: - * arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); - * - * 1. We have pairs: (-, 0.1), (-, 0.4), (+, 0.35), (+, 0.8) - * - * 2. Let's sort by score: (-, 0.1), (+, 0.35), (-, 0.4), (+, 0.8) - * - * 3. Let's draw the points: - * - * threshold = 0, TPR_raw = 2, FPR_raw = 2 - * threshold = 0.1, TPR_raw = 2, FPR_raw = 1 - * threshold = 0.35, TPR_raw = 1, FPR_raw = 1 - * threshold = 0.4, TPR_raw = 1, FPR_raw = 0 - * threshold = 0.8, TPR_raw = 0, FPR_raw = 0 - * - * The "curve" will be present by a line that moves one step either towards right or top on each threshold change. - */ - -class FunctionArrayAUCUnscaled : public IFunction -{ -public: - static constexpr auto name = "arrayAUCUnscaled"; - static FunctionPtr create(ContextPtr) { return std::make_shared(); } - -private: - static Float64 apply( - const IColumn & scores, - const IColumn & labels, - ColumnArray::Offset current_offset, - ColumnArray::Offset next_offset) - { - struct ScoreLabel - { - Float64 score; - bool label; - }; - - size_t size = next_offset - current_offset; - PODArrayWithStackMemory sorted_labels(size); - - for (size_t i = 0; i < size; ++i) - { - bool label = labels.getFloat64(current_offset + i) > 0; - sorted_labels[i].score = scores.getFloat64(current_offset + i); - sorted_labels[i].label = label; - } - - /// Sorting scores in descending order to traverse the ROC curve from left to right - std::sort(sorted_labels.begin(), sorted_labels.end(), [](const auto & lhs, const auto & rhs) { return lhs.score > rhs.score; }); - - Float64 area = 0.0; - Float64 prev_score = sorted_labels[0].score; - size_t prev_fp = 0, prev_tp = 0; - size_t curr_fp = 0, curr_tp = 0; - for (size_t i = 0; i < size; ++i) - { - // Only increment the area when the score changes - if (sorted_labels[i].score != prev_score) - { - area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; // Trapezoidal area under curve (might degenerate to zero or to a rectangle) - prev_fp = curr_fp; - prev_tp = curr_tp; - prev_score = sorted_labels[i].score; - } - - if (sorted_labels[i].label) - curr_tp += 1; /// The curve moves one step up. - else - curr_fp += 1; /// The curve moves one step right. - } - - area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; - - return area; - } - - static void vector( - const IColumn & scores, - const IColumn & labels, - const ColumnArray::Offsets & offsets, - PaddedPODArray & result, - size_t input_rows_count) - { - result.resize(input_rows_count); - - ColumnArray::Offset current_offset = 0; - for (size_t i = 0; i < input_rows_count; ++i) - { - auto next_offset = offsets[i]; - result[i] = apply(scores, labels, current_offset, next_offset); - current_offset = next_offset; - } - } - -public: - String getName() const override { return name; } - size_t getNumberOfArguments() const override { return 2; } - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return false; } - - DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override - { - for (size_t i = 0; i < getNumberOfArguments(); ++i) - { - const DataTypeArray * array_type = checkAndGetDataType(arguments[i].get()); - if (!array_type) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "All arguments for function {} must be an array.", getName()); - - const auto & nested_type = array_type->getNestedType(); - if (!isNativeNumber(nested_type) && !isEnum(nested_type)) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}", - getName(), nested_type->getName()); - } - - return std::make_shared(); - } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override - { - ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst(); - ColumnPtr col2 = arguments[1].column->convertToFullColumnIfConst(); - - const ColumnArray * col_array1 = checkAndGetColumn(col1.get()); - if (!col_array1) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Illegal column {} of first argument of function {}", arguments[0].column->getName(), getName()); - - const ColumnArray * col_array2 = checkAndGetColumn(col2.get()); - if (!col_array2) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Illegal column {} of second argument of function {}", arguments[1].column->getName(), getName()); - - if (!col_array1->hasEqualOffsets(*col_array2)) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Array arguments for function {} must have equal sizes", getName()); - - auto col_res = ColumnVector::create(); - - vector( - col_array1->getData(), - col_array2->getData(), - col_array1->getOffsets(), - col_res->getData(), - input_rows_count); - - return col_res; - } -}; - - -REGISTER_FUNCTION(ArrayAUCUnscaled) -{ - factory.registerFunction(); -} - -} diff --git a/tests/fuzz/all.dict b/tests/fuzz/all.dict index 6cb198a3e48..30af3746fca 100644 --- a/tests/fuzz/all.dict +++ b/tests/fuzz/all.dict @@ -1216,7 +1216,6 @@ "argMinState" "array" "arrayAUC" -"arrayAUCUnscaled" "arrayAll" "arrayAvg" "arrayCompact" diff --git a/tests/fuzz/dictionaries/functions.dict b/tests/fuzz/dictionaries/functions.dict index 302aab97c2d..e562595fb67 100644 --- a/tests/fuzz/dictionaries/functions.dict +++ b/tests/fuzz/dictionaries/functions.dict @@ -529,7 +529,6 @@ "argMinState" "array" "arrayAUC" -"arrayAUCUnscaled" "arrayAll" "arrayAvg" "arrayCompact" diff --git a/tests/fuzz/dictionaries/old.dict b/tests/fuzz/dictionaries/old.dict index 6ecb5503ca4..61914c3b283 100644 --- a/tests/fuzz/dictionaries/old.dict +++ b/tests/fuzz/dictionaries/old.dict @@ -19,7 +19,6 @@ "Array" "arrayAll" "arrayAUC" -"arrayAUCUnscaled" "arrayCompact" "arrayConcat" "arrayCount" diff --git a/tests/queries/0_stateless/03237_array_auc_unscaled.sql b/tests/queries/0_stateless/03237_array_auc_unscaled.sql index d4f07c42118..4083e836067 100644 --- a/tests/queries/0_stateless/03237_array_auc_unscaled.sql +++ b/tests/queries/0_stateless/03237_array_auc_unscaled.sql @@ -1,30 +1,30 @@ -select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]); -select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8))); -select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8))); -select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1)))); -select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1)))); -select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]); -select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]); -select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]); -select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]); -select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]); -select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]); -select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]); -select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]); -select arrayAUCUnscaled(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]); -select arrayAUCUnscaled([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]); -select arrayAUCUnscaled([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]); -SELECT arrayAUCUnscaled([], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUCUnscaled([1], [1]); -SELECT arrayAUCUnscaled([1], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUCUnscaled([], [1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUCUnscaled([1, 2], [3]); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUCUnscaled([1], [2, 3]); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUCUnscaled([1, 1], [1, 1]); -SELECT arrayAUCUnscaled([1, 1], [0, 0]); -SELECT arrayAUCUnscaled([1, 1], [0, 1]); -SELECT arrayAUCUnscaled([0, 1], [0, 1]); -SELECT arrayAUCUnscaled([1, 0], [0, 1]); -SELECT arrayAUCUnscaled([0, 0, 1], [0, 1, 1]); -SELECT arrayAUCUnscaled([0, 1, 1], [0, 1, 1]); -SELECT arrayAUCUnscaled([0, 1, 1], [0, 0, 1]); +select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), false); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), false); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), false); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), false); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], false); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], false); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], false); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], false); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], false); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], false); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], false); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], false); +select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], false); +select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], false); +select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], false); +SELECT arrayAUC([], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([1], [1], false); +SELECT arrayAUC([1], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([], [1], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([1, 2], [3], false); -- { serverError BAD_ARGUMENTS } +SELECT arrayAUC([1], [2, 3], false); -- { serverError BAD_ARGUMENTS } +SELECT arrayAUC([1, 1], [1, 1], false); +SELECT arrayAUC([1, 1], [0, 0], false); +SELECT arrayAUC([1, 1], [0, 1], false); +SELECT arrayAUC([0, 1], [0, 1], false); +SELECT arrayAUC([1, 0], [0, 1], false); +SELECT arrayAUC([0, 0, 1], [0, 1, 1], false); +SELECT arrayAUC([0, 1, 1], [0, 1, 1], false); +SELECT arrayAUC([0, 1, 1], [0, 0, 1], false); diff --git a/tests/queries/0_stateless/03237_array_auc_unscaled.stdout-e b/tests/queries/0_stateless/03237_array_auc_unscaled.stdout-e new file mode 100644 index 00000000000..e6b5161afe3 --- /dev/null +++ b/tests/queries/0_stateless/03237_array_auc_unscaled.stdout-e @@ -0,0 +1,25 @@ +3 +3 +3 +3 +3 +3 +3 +3 +3 +1 +1 +1 +1 +1 +1 +1 +0 +nan +0 +0.5 +1 +0 +1.5 +2 +1.5 diff --git a/utils/check-style/aspell-ignore/en/aspell-dict.txt b/utils/check-style/aspell-ignore/en/aspell-dict.txt index f658b19e8a7..3467f21c812 100644 --- a/utils/check-style/aspell-ignore/en/aspell-dict.txt +++ b/utils/check-style/aspell-ignore/en/aspell-dict.txt @@ -1153,7 +1153,6 @@ argMin argmax argmin arrayAUC -arrayAUCUnscaled arrayAll arrayAvg arrayCompact From 4c72fb0e32f4e872c897dcd413900d94a5d3fb27 Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 08:56:13 -0300 Subject: [PATCH 03/11] remove unnecessary file --- .../03237_array_auc_unscaled.stdout-e | 25 ------------------- 1 file changed, 25 deletions(-) delete mode 100644 tests/queries/0_stateless/03237_array_auc_unscaled.stdout-e diff --git a/tests/queries/0_stateless/03237_array_auc_unscaled.stdout-e b/tests/queries/0_stateless/03237_array_auc_unscaled.stdout-e deleted file mode 100644 index e6b5161afe3..00000000000 --- a/tests/queries/0_stateless/03237_array_auc_unscaled.stdout-e +++ /dev/null @@ -1,25 +0,0 @@ -3 -3 -3 -3 -3 -3 -3 -3 -3 -1 -1 -1 -1 -1 -1 -1 -0 -nan -0 -0.5 -1 -0 -1.5 -2 -1.5 From 4be8a0feba0646de617a366b6b7f3411158ada19 Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 08:58:14 -0300 Subject: [PATCH 04/11] fmt --- src/Functions/array/arrayAUC.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index adaa52818d0..999aa999015 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -140,7 +140,7 @@ private: if (scale) return area / curr_tp / (size - curr_tp); else - return area; + return area; } static void vector( @@ -212,7 +212,7 @@ public: bool scale = true; if (arguments.size() == 3) { - scale = arguments[2].column->getBool(0); // Assumes it's a scalar boolean column + scale = arguments[2].column->getBool(0); // Assumes it's a scalar boolean column } auto col_res = ColumnVector::create(); From e3b207d21702dc1b38ccbdb5b5b5cfa77ecf7617 Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 09:03:29 -0300 Subject: [PATCH 05/11] fmt --- src/Functions/array/arrayAUC.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index 999aa999015..c770899214c 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -139,8 +139,8 @@ private: if (scale) return area / curr_tp / (size - curr_tp); - else - return area; + + return area; } static void vector( From b94017125215599ce7ac09ad577a2431d9d078d1 Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 11:04:46 -0300 Subject: [PATCH 06/11] fix tests --- src/Functions/array/arrayAUC.cpp | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index c770899214c..4f37617ab79 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -14,6 +14,7 @@ namespace ErrorCodes extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ILLEGAL_COLUMN; extern const int BAD_ARGUMENTS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } @@ -164,11 +165,21 @@ private: public: String getName() const override { return name; } - size_t getNumberOfArguments() const override { return 3; } - bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return false; } + + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { + const size_t number_of_arguments = arguments.size(); + + if (number_of_arguments < 2 || number_of_arguments > 3) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3", + getName(), number_of_arguments); + for (size_t i = 0; i < 2; ++i) { const DataTypeArray * array_type = checkAndGetDataType(arguments[i].get()); @@ -181,7 +192,7 @@ public: getName(), nested_type->getName()); } - if (arguments.size() == 3) + if (number_of_arguments == 3) { if (!isBool(arguments[2])) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument must be a boolean (scale)"); @@ -192,6 +203,8 @@ public: ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { + const size_t number_of_arguments = arguments.size(); + ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst(); ColumnPtr col2 = arguments[1].column->convertToFullColumnIfConst(); @@ -210,7 +223,7 @@ public: // Handle third argument for scale (if passed, otherwise default to true) bool scale = true; - if (arguments.size() == 3) + if (number_of_arguments == 3) { scale = arguments[2].column->getBool(0); // Assumes it's a scalar boolean column } From e0fc95c894fcd129566b7c374496423e30f94c2c Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 11:12:30 -0300 Subject: [PATCH 07/11] remove trailing spaces --- src/Functions/array/arrayAUC.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index 4f37617ab79..04ebb6d5bac 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -204,7 +204,7 @@ public: ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { const size_t number_of_arguments = arguments.size(); - + ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst(); ColumnPtr col2 = arguments[1].column->convertToFullColumnIfConst(); From 02fcd90a66ea61dee36db290a089a7cb48142ba4 Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 14:13:22 -0300 Subject: [PATCH 08/11] address some pr comments --- src/Functions/array/arrayAUC.cpp | 22 ++++++------- .../0_stateless/01064_array_auc.reference | 16 ++++++++++ tests/queries/0_stateless/01064_array_auc.sql | 18 ++++++++++- .../0_stateless/01064_array_auc.stdout-e | 32 +++++++++++++++++++ .../01202_array_auc_special.reference | 9 ++++++ .../0_stateless/01202_array_auc_special.sql | 14 ++++++++ ...rence => 01202_array_auc_special.stdout-e} | 21 ++++-------- .../0_stateless/03237_array_auc_unscaled.sql | 30 ----------------- 8 files changed, 106 insertions(+), 56 deletions(-) create mode 100644 tests/queries/0_stateless/01064_array_auc.stdout-e rename tests/queries/0_stateless/{03237_array_auc_unscaled.reference => 01202_array_auc_special.stdout-e} (50%) delete mode 100644 tests/queries/0_stateless/03237_array_auc_unscaled.sql diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index 04ebb6d5bac..5577a51e198 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -87,7 +87,7 @@ private: const IColumn & labels, ColumnArray::Offset current_offset, ColumnArray::Offset next_offset, - bool scale = true) + bool scale) { struct ScoreLabel { @@ -116,10 +116,10 @@ private: size_t curr_fp = 0, curr_tp = 0; for (size_t i = 0; i < size; ++i) { - // Only increment the area when the score changes + /// Only increment the area when the score changes if (sorted_labels[i].score != prev_score) { - area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; // Trapezoidal area under curve (might degenerate to zero or to a rectangle) + area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; /// Trapezoidal area under curve (might degenerate to zero or to a rectangle) prev_fp = curr_fp; prev_tp = curr_tp; prev_score = sorted_labels[i].score; @@ -135,12 +135,12 @@ private: /// Then normalize it, if scale is true, dividing by the area to the area of rectangle. - if (scale && (curr_tp == 0 || curr_tp == size)) - return std::numeric_limits::quiet_NaN(); - if (scale) + { + if (curr_tp == 0 || curr_tp == size) + return std::numeric_limits::quiet_NaN(); return area / curr_tp / (size - curr_tp); - + } return area; } @@ -150,7 +150,7 @@ private: const ColumnArray::Offsets & offsets, PaddedPODArray & result, size_t input_rows_count, - bool scale = true) + bool scale) { result.resize(input_rows_count); @@ -195,7 +195,7 @@ public: if (number_of_arguments == 3) { if (!isBool(arguments[2])) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument must be a boolean (scale)"); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument (scale) for function {} must be a bool.", getName()); } return std::make_shared(); @@ -221,11 +221,11 @@ public: if (!col_array1->hasEqualOffsets(*col_array2)) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Array arguments for function {} must have equal sizes", getName()); - // Handle third argument for scale (if passed, otherwise default to true) + /// Handle third argument for scale (if passed, otherwise default to true) bool scale = true; if (number_of_arguments == 3) { - scale = arguments[2].column->getBool(0); // Assumes it's a scalar boolean column + scale = arguments[2].column->getBool(0); /// Assumes it's a scalar boolean column } auto col_res = ColumnVector::create(); diff --git a/tests/queries/0_stateless/01064_array_auc.reference b/tests/queries/0_stateless/01064_array_auc.reference index 8c17bba359a..8b5c852a38b 100644 --- a/tests/queries/0_stateless/01064_array_auc.reference +++ b/tests/queries/0_stateless/01064_array_auc.reference @@ -14,3 +14,19 @@ 0.25 0.125 0.25 +3 +3 +3 +3 +3 +3 +3 +3 +3 +1 +1 +1 +1 +1 +1 +1 diff --git a/tests/queries/0_stateless/01064_array_auc.sql b/tests/queries/0_stateless/01064_array_auc.sql index de05c47c51b..94767b72931 100644 --- a/tests/queries/0_stateless/01064_array_auc.sql +++ b/tests/queries/0_stateless/01064_array_auc.sql @@ -13,4 +13,20 @@ select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]); select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]); select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]); select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]); -select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]); \ No newline at end of file +select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]); +select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), false); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), false); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), false); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), false); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], false); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], false); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], false); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], false); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], false); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], false); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], false); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], false); +select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], false); +select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], false); +select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], false); diff --git a/tests/queries/0_stateless/01064_array_auc.stdout-e b/tests/queries/0_stateless/01064_array_auc.stdout-e new file mode 100644 index 00000000000..8b5c852a38b --- /dev/null +++ b/tests/queries/0_stateless/01064_array_auc.stdout-e @@ -0,0 +1,32 @@ +0.75 +0.75 +0.75 +0.75 +0.75 +0.75 +0.75 +0.75 +0.75 +0.25 +0.25 +0.25 +0.25 +0.25 +0.125 +0.25 +3 +3 +3 +3 +3 +3 +3 +3 +3 +1 +1 +1 +1 +1 +1 +1 diff --git a/tests/queries/0_stateless/01202_array_auc_special.reference b/tests/queries/0_stateless/01202_array_auc_special.reference index 8f3f0cf1efe..cb25b381ff7 100644 --- a/tests/queries/0_stateless/01202_array_auc_special.reference +++ b/tests/queries/0_stateless/01202_array_auc_special.reference @@ -7,3 +7,12 @@ nan 0.75 1 0.75 +0 +0 +0 +0.5 +1 +0 +1.5 +2 +1.5 diff --git a/tests/queries/0_stateless/01202_array_auc_special.sql b/tests/queries/0_stateless/01202_array_auc_special.sql index e379050a982..f22524c2756 100644 --- a/tests/queries/0_stateless/01202_array_auc_special.sql +++ b/tests/queries/0_stateless/01202_array_auc_special.sql @@ -12,3 +12,17 @@ SELECT arrayAUC([1, 0], [0, 1]); SELECT arrayAUC([0, 0, 1], [0, 1, 1]); SELECT arrayAUC([0, 1, 1], [0, 1, 1]); SELECT arrayAUC([0, 1, 1], [0, 0, 1]); +SELECT arrayAUC([], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([1], [1], false); +SELECT arrayAUC([1], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([], [1], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([1, 2], [3], false); -- { serverError BAD_ARGUMENTS } +SELECT arrayAUC([1], [2, 3], false); -- { serverError BAD_ARGUMENTS } +SELECT arrayAUC([1, 1], [1, 1], false); +SELECT arrayAUC([1, 1], [0, 0], false); +SELECT arrayAUC([1, 1], [0, 1], false); +SELECT arrayAUC([0, 1], [0, 1], false); +SELECT arrayAUC([1, 0], [0, 1], false); +SELECT arrayAUC([0, 0, 1], [0, 1, 1], false); +SELECT arrayAUC([0, 1, 1], [0, 1, 1], false); +SELECT arrayAUC([0, 1, 1], [0, 0, 1], false); diff --git a/tests/queries/0_stateless/03237_array_auc_unscaled.reference b/tests/queries/0_stateless/01202_array_auc_special.stdout-e similarity index 50% rename from tests/queries/0_stateless/03237_array_auc_unscaled.reference rename to tests/queries/0_stateless/01202_array_auc_special.stdout-e index 63204682fd4..cb25b381ff7 100644 --- a/tests/queries/0_stateless/03237_array_auc_unscaled.reference +++ b/tests/queries/0_stateless/01202_array_auc_special.stdout-e @@ -1,19 +1,12 @@ -3 -3 -3 -3 -3 -3 -3 -3 -3 -1 -1 -1 -1 -1 +nan +nan +nan +0.5 1 +0 +0.75 1 +0.75 0 0 0 diff --git a/tests/queries/0_stateless/03237_array_auc_unscaled.sql b/tests/queries/0_stateless/03237_array_auc_unscaled.sql deleted file mode 100644 index 4083e836067..00000000000 --- a/tests/queries/0_stateless/03237_array_auc_unscaled.sql +++ /dev/null @@ -1,30 +0,0 @@ -select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), false); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), false); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), false); -select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), false); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], false); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], false); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], false); -select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], false); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], false); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], false); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], false); -select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], false); -select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], false); -select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], false); -select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], false); -SELECT arrayAUC([], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([1], [1], false); -SELECT arrayAUC([1], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([], [1], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayAUC([1, 2], [3], false); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUC([1], [2, 3], false); -- { serverError BAD_ARGUMENTS } -SELECT arrayAUC([1, 1], [1, 1], false); -SELECT arrayAUC([1, 1], [0, 0], false); -SELECT arrayAUC([1, 1], [0, 1], false); -SELECT arrayAUC([0, 1], [0, 1], false); -SELECT arrayAUC([1, 0], [0, 1], false); -SELECT arrayAUC([0, 0, 1], [0, 1, 1], false); -SELECT arrayAUC([0, 1, 1], [0, 1, 1], false); -SELECT arrayAUC([0, 1, 1], [0, 0, 1], false); From 006d14445eb619eb518e8c5f6879d0b41b77cd13 Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 14:59:58 -0300 Subject: [PATCH 09/11] remove stdout files --- .../0_stateless/01064_array_auc.stdout-e | 32 ------------------- .../01202_array_auc_special.stdout-e | 18 ----------- 2 files changed, 50 deletions(-) delete mode 100644 tests/queries/0_stateless/01064_array_auc.stdout-e delete mode 100644 tests/queries/0_stateless/01202_array_auc_special.stdout-e diff --git a/tests/queries/0_stateless/01064_array_auc.stdout-e b/tests/queries/0_stateless/01064_array_auc.stdout-e deleted file mode 100644 index 8b5c852a38b..00000000000 --- a/tests/queries/0_stateless/01064_array_auc.stdout-e +++ /dev/null @@ -1,32 +0,0 @@ -0.75 -0.75 -0.75 -0.75 -0.75 -0.75 -0.75 -0.75 -0.75 -0.25 -0.25 -0.25 -0.25 -0.25 -0.125 -0.25 -3 -3 -3 -3 -3 -3 -3 -3 -3 -1 -1 -1 -1 -1 -1 -1 diff --git a/tests/queries/0_stateless/01202_array_auc_special.stdout-e b/tests/queries/0_stateless/01202_array_auc_special.stdout-e deleted file mode 100644 index cb25b381ff7..00000000000 --- a/tests/queries/0_stateless/01202_array_auc_special.stdout-e +++ /dev/null @@ -1,18 +0,0 @@ -nan -nan -nan -0.5 -1 -0 -0.75 -1 -0.75 -0 -0 -0 -0.5 -1 -0 -1.5 -2 -1.5 From 7f0b7a915808fe250bbb8da747ae37db5bd096da Mon Sep 17 00:00:00 2001 From: Gabriel Mendes Date: Wed, 18 Sep 2024 15:17:54 -0300 Subject: [PATCH 10/11] add tests to cover all possible flows --- .../0_stateless/01064_array_auc.reference | 16 ++++++++++++++++ tests/queries/0_stateless/01064_array_auc.sql | 16 ++++++++++++++++ .../01202_array_auc_special.reference | 9 +++++++++ .../0_stateless/01202_array_auc_special.sql | 18 ++++++++++++++++++ 4 files changed, 59 insertions(+) diff --git a/tests/queries/0_stateless/01064_array_auc.reference b/tests/queries/0_stateless/01064_array_auc.reference index 8b5c852a38b..3fd5483eb99 100644 --- a/tests/queries/0_stateless/01064_array_auc.reference +++ b/tests/queries/0_stateless/01064_array_auc.reference @@ -14,6 +14,22 @@ 0.25 0.125 0.25 +0.75 +0.75 +0.75 +0.75 +0.75 +0.75 +0.75 +0.75 +0.75 +0.25 +0.25 +0.25 +0.25 +0.25 +0.125 +0.25 3 3 3 diff --git a/tests/queries/0_stateless/01064_array_auc.sql b/tests/queries/0_stateless/01064_array_auc.sql index 94767b72931..adc32abcde8 100644 --- a/tests/queries/0_stateless/01064_array_auc.sql +++ b/tests/queries/0_stateless/01064_array_auc.sql @@ -14,6 +14,22 @@ select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]); select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]); select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]); select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]); +select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), true); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), true); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), true); +select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), true); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], true); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], true); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], true); +select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], true); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], true); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], true); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], true); +select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], true); +select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], true); +select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], true); +select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], true); select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false); select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), false); select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), false); diff --git a/tests/queries/0_stateless/01202_array_auc_special.reference b/tests/queries/0_stateless/01202_array_auc_special.reference index cb25b381ff7..8f4b9495a5c 100644 --- a/tests/queries/0_stateless/01202_array_auc_special.reference +++ b/tests/queries/0_stateless/01202_array_auc_special.reference @@ -7,6 +7,15 @@ nan 0.75 1 0.75 +nan +nan +nan +0.5 +1 +0 +0.75 +1 +0.75 0 0 0 diff --git a/tests/queries/0_stateless/01202_array_auc_special.sql b/tests/queries/0_stateless/01202_array_auc_special.sql index f22524c2756..a7276ec0620 100644 --- a/tests/queries/0_stateless/01202_array_auc_special.sql +++ b/tests/queries/0_stateless/01202_array_auc_special.sql @@ -12,6 +12,20 @@ SELECT arrayAUC([1, 0], [0, 1]); SELECT arrayAUC([0, 0, 1], [0, 1, 1]); SELECT arrayAUC([0, 1, 1], [0, 1, 1]); SELECT arrayAUC([0, 1, 1], [0, 0, 1]); +SELECT arrayAUC([], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([1], [1], true); +SELECT arrayAUC([1], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([], [1], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([1, 2], [3], true); -- { serverError BAD_ARGUMENTS } +SELECT arrayAUC([1], [2, 3], true); -- { serverError BAD_ARGUMENTS } +SELECT arrayAUC([1, 1], [1, 1], true); +SELECT arrayAUC([1, 1], [0, 0], true); +SELECT arrayAUC([1, 1], [0, 1], true); +SELECT arrayAUC([0, 1], [0, 1], true); +SELECT arrayAUC([1, 0], [0, 1], true); +SELECT arrayAUC([0, 0, 1], [0, 1, 1], true); +SELECT arrayAUC([0, 1, 1], [0, 1, 1], true); +SELECT arrayAUC([0, 1, 1], [0, 0, 1], true); SELECT arrayAUC([], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } SELECT arrayAUC([1], [1], false); SELECT arrayAUC([1], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } @@ -26,3 +40,7 @@ SELECT arrayAUC([1, 0], [0, 1], false); SELECT arrayAUC([0, 0, 1], [0, 1, 1], false); SELECT arrayAUC([0, 1, 1], [0, 1, 1], false); SELECT arrayAUC([0, 1, 1], [0, 0, 1], false); +SELECT arrayAUC([0, 1, 1], [0, 0, 1], false, true); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +SELECT arrayAUC([0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +SELECT arrayAUC([0, 1, 1], [0, 0, 1], 'false'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayAUC([0, 1, 1], [0, 0, 1], 4); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } \ No newline at end of file From bb6db8926e52895e32f3c978e05ee45b4b404cfb Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Wed, 18 Sep 2024 20:48:36 +0000 Subject: [PATCH 11/11] Some fixups --- src/Functions/array/arrayAUC.cpp | 23 ++++++++----------- tests/queries/0_stateless/01064_array_auc.sql | 8 +++++++ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index 5577a51e198..68cc292e0e7 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -171,9 +171,9 @@ public: bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; } - DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { - const size_t number_of_arguments = arguments.size(); + size_t number_of_arguments = arguments.size(); if (number_of_arguments < 2 || number_of_arguments > 3) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, @@ -182,20 +182,19 @@ public: for (size_t i = 0; i < 2; ++i) { - const DataTypeArray * array_type = checkAndGetDataType(arguments[i].get()); + const DataTypeArray * array_type = checkAndGetDataType(arguments[i].type.get()); if (!array_type) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The two first arguments for function {} must be an array.", getName()); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The two first arguments for function {} must be of type Array.", getName()); const auto & nested_type = array_type->getNestedType(); if (!isNativeNumber(nested_type) && !isEnum(nested_type)) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}", - getName(), nested_type->getName()); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}", getName(), nested_type->getName()); } if (number_of_arguments == 3) { - if (!isBool(arguments[2])) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument (scale) for function {} must be a bool.", getName()); + if (!isBool(arguments[2].type) || arguments[2].column.get() == nullptr || !isColumnConst(*arguments[2].column)) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument (scale) for function {} must be of type const Bool.", getName()); } return std::make_shared(); @@ -203,7 +202,7 @@ public: ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { - const size_t number_of_arguments = arguments.size(); + size_t number_of_arguments = arguments.size(); ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst(); ColumnPtr col2 = arguments[1].column->convertToFullColumnIfConst(); @@ -223,10 +222,8 @@ public: /// Handle third argument for scale (if passed, otherwise default to true) bool scale = true; - if (number_of_arguments == 3) - { - scale = arguments[2].column->getBool(0); /// Assumes it's a scalar boolean column - } + if (number_of_arguments == 3 && input_rows_count > 0) + scale = arguments[2].column->getBool(0); auto col_res = ColumnVector::create(); diff --git a/tests/queries/0_stateless/01064_array_auc.sql b/tests/queries/0_stateless/01064_array_auc.sql index adc32abcde8..5594b505223 100644 --- a/tests/queries/0_stateless/01064_array_auc.sql +++ b/tests/queries/0_stateless/01064_array_auc.sql @@ -14,6 +14,7 @@ select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]); select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]); select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]); select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]); + select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true); select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), true); select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), true); @@ -30,6 +31,7 @@ select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], true); select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], true); select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], true); select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], true); + select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false); select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), false); select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), false); @@ -46,3 +48,9 @@ select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], false) select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], false); select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], false); select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], false); + +-- negative tests +select arrayAUC([0, 0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +select arrayAUC([0.1, 0.35], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS } +select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], materialize(true)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true, true); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }