initial commit, tested function

Gabriel Mendes 2024-09-18 05:15:57 -03:00
parent 2cef99c311
commit 2218ebebbf
10 changed files with 378 additions and 0 deletions


@ -2116,6 +2116,41 @@ Result:
└───────────────────────────────────────────────┘
```
## arrayAUCUnscaled
Calculates the unscaled AUC (Area Under the Curve, a concept in machine learning; for details see <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>), i.e. the area under the ROC curve without dividing it by the product of the total number of positive samples and the total number of negative samples.
**Syntax**
``` sql
arrayAUCUnscaled(arr_scores, arr_labels)
```
**Arguments**
- `arr_scores` — Scores that the prediction model gives.
- `arr_labels` — Labels of samples, usually 1 for a positive sample and 0 for a negative sample.
**Returned value**
Returns the unscaled AUC value with type Float64.
**Example**
Query:
``` sql
select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
```
Result:
``` text
┌─arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐
│ 3.0 │
└───────────────────────────────────────────────────────┘
```
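The result above can be reproduced by hand. The following is a minimal C++ sketch, not the ClickHouse implementation itself; the helper name `aucUnscaled` and its `std::vector` interface are assumptions made for illustration. It sorts the pairs by descending score and sums trapezoids in raw (false positive count, true positive count) coordinates, treating any label greater than zero as positive.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper (not part of ClickHouse): unscaled area under the ROC curve.
// A label counts as positive when it is greater than zero.
double aucUnscaled(const std::vector<double> & scores, const std::vector<int> & labels)
{
    assert(scores.size() == labels.size());

    std::vector<std::pair<double, bool>> items;
    for (std::size_t i = 0; i < scores.size(); ++i)
        items.emplace_back(scores[i], labels[i] > 0);
    if (items.empty())
        return 0.0;

    // Highest score first, so the curve is traversed starting from the origin.
    std::sort(items.begin(), items.end(),
              [](const auto & a, const auto & b) { return a.first > b.first; });

    double area = 0.0;
    double prev_score = items.front().first;
    std::size_t prev_fp = 0, prev_tp = 0, fp = 0, tp = 0;
    for (const auto & [score, positive] : items)
    {
        if (score != prev_score)
        {
            // Close the trapezoid spanned since the last distinct score.
            area += (fp - prev_fp) * (tp + prev_tp) / 2.0;
            prev_fp = fp;
            prev_tp = tp;
            prev_score = score;
        }
        if (positive)
            ++tp; // one step up
        else
            ++fp; // one step right
    }
    // Close the last trapezoid.
    area += (fp - prev_fp) * (tp + prev_tp) / 2.0;
    return area;
}
```

Under these assumptions, `aucUnscaled({0.1, 0.4, 0.35, 0.8}, {0, 0, 1, 1})` evaluates to `3.0`, matching the query result. Dividing by the product of the positive and negative counts (2 × 2) gives the scaled value 0.75.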
## arrayMap(func, arr1, ...)
Returns an array obtained from the original arrays by application of `func(arr1[i], ..., arrN[i])` for each element. Arrays `arr1` ... `arrN` must have the same number of elements.


@ -1654,6 +1654,43 @@ SELECT arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
└────────────────────────────────────────---──┘
```
## arrayAUCUnscaled {#arrayaucunscaled}
Calculates the area under the curve without normalization.
**Syntax**
``` sql
arrayAUCUnscaled(arr_scores, arr_labels)
```
**Arguments**
- `arr_scores` — Scores that the prediction model gives.
- `arr_labels` — Labels of samples, usually 1 for a positive sample and 0 for a negative sample.
**Returned value**
The value of the area under the curve without normalization.
Data type: `Float64`.
**Example**
Query:
``` sql
SELECT arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
```
Result:
``` text
┌─arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐
│ 3.0 │
└───────────────────────────────────────────────────────┘
```
## arrayProduct {#arrayproduct}
Returns the product of the elements of an [array](../../sql-reference/data-types/array.md).


@ -1221,6 +1221,41 @@ select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
└───────────────────────────────────────────────┘
```
## arrayAUCUnscaled {#arrayaucunscaled}
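Calculates the AUC without normalization (the area under the ROC curve, a concept in machine learning; for more details see https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve).
**Syntax**
``` sql
arrayAUCUnscaled(arr_scores, arr_labels)
```
**Arguments**
- `arr_scores` — Scores that the prediction model gives.
- `arr_labels` — Labels of samples, usually 1 for a positive sample and 0 for a negative sample.
**Returned value**
Returns the unscaled AUC value with type Float64.
**Example**
Query:
``` sql
select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
```
Result: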
``` text
┌─arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐
│ 3.0 │
└───────────────────────────────────────────────────────┘
```
## arrayMap(func, arr1, ...) {#array-map}
Returns an array obtained by applying the `func` function to each element of the `arr` array.


@ -0,0 +1,212 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnArray.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int BAD_ARGUMENTS;
}
/** The function takes two arrays: scores and labels.
* Label can be one of two values: positive and negative.
* Score can be arbitrary number.
*
* These values are considered as the output of classifier. We have some true labels for objects.
* And classifier assigns some scores to objects that predict these labels in the following way:
* - we can define arbitrary threshold on score and predict that the label is positive if the score is greater than the threshold:
*
* f(object) = score
* predicted_label = score > threshold
*
* This way classifier may predict positive or negative value correctly - true positive or true negative
* or have false positive or false negative result.
* Varying the threshold, we can get different probabilities of false positives, false negatives, true positives, etc.
*
* We can also calculate the True Positive Rate and the False Positive Rate:
*
* TPR (also called "sensitivity", "recall" or "probability of detection")
* is the probability of classifier to give positive result if the object has positive label:
* TPR = P(score > threshold | label = positive)
*
* FPR is the probability of classifier to give positive result if the object has negative label:
* FPR = P(score > threshold | label = negative)
*
* We can draw a curve of values of FPR and TPR with different threshold on [0..1] x [0..1] unit square.
* This curve is named "ROC curve" (Receiver Operating Characteristic).
*
* For ROC we can calculate, literally, Area Under the Curve, that will be in the range of [0..1].
* The higher the AUC the better the classifier.
*
* AUC is also equal to the probability that the score of a randomly chosen positive object is greater than the score of a randomly chosen negative object.
*
* https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc
* https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve
*
* To calculate AUC, we will draw points of (FPR, TPR) for different thresholds = score_i.
* FPR_raw = countIf(score > score_i, label = negative) = count negative labels above certain score
* TPR_raw = countIf(score > score_i, label = positive) = count positive labels above certain score
*
* Let's look at the example:
* arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
*
* 1. We have pairs: (-, 0.1), (-, 0.4), (+, 0.35), (+, 0.8)
*
* 2. Let's sort by score: (-, 0.1), (+, 0.35), (-, 0.4), (+, 0.8)
*
* 3. Let's draw the points:
*
* threshold = 0, TPR_raw = 2, FPR_raw = 2
* threshold = 0.1, TPR_raw = 2, FPR_raw = 1
* threshold = 0.35, TPR_raw = 1, FPR_raw = 1
* threshold = 0.4, TPR_raw = 1, FPR_raw = 0
* threshold = 0.8, TPR_raw = 0, FPR_raw = 0
*
* The "curve" is represented by a line that moves one step either to the right or upward on each threshold change.
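*
* 4. As a sketch of the arithmetic for this example (traversing the sorted pairs from the highest score down),
* the raw points (FPR_raw, TPR_raw) are (0, 0) -> (0, 1) -> (1, 1) -> (1, 2) -> (2, 2).
* Each step to the right adds a column whose height is the current TPR_raw: 1 * 1 + 1 * 2 = 3.
* Dividing by (total positives * total negatives) = 2 * 2 would give the scaled AUC, 0.75.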
*/
class FunctionArrayAUCUnscaled : public IFunction
{
public:
static constexpr auto name = "arrayAUCUnscaled";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayAUCUnscaled>(); }
private:
static Float64 apply(
const IColumn & scores,
const IColumn & labels,
ColumnArray::Offset current_offset,
ColumnArray::Offset next_offset)
{
struct ScoreLabel
{
Float64 score;
bool label;
};
size_t size = next_offset - current_offset;
PODArrayWithStackMemory<ScoreLabel, 1024> sorted_labels(size);
for (size_t i = 0; i < size; ++i)
{
bool label = labels.getFloat64(current_offset + i) > 0;
sorted_labels[i].score = scores.getFloat64(current_offset + i);
sorted_labels[i].label = label;
}
/// Sorting scores in descending order to traverse the ROC curve from left to right
std::sort(sorted_labels.begin(), sorted_labels.end(), [](const auto & lhs, const auto & rhs) { return lhs.score > rhs.score; });
Float64 area = 0.0;
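/// An empty array has no points; return early instead of reading sorted_labels[0].
if (size == 0)
return area;
Float64 prev_score = sorted_labels[0].score;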
size_t prev_fp = 0, prev_tp = 0;
size_t curr_fp = 0, curr_tp = 0;
for (size_t i = 0; i < size; ++i)
{
// Only increment the area when the score changes
if (sorted_labels[i].score != prev_score)
{
area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; // Trapezoidal area under curve (might degenerate to zero or to a rectangle)
prev_fp = curr_fp;
prev_tp = curr_tp;
prev_score = sorted_labels[i].score;
}
if (sorted_labels[i].label)
curr_tp += 1; /// The curve moves one step up.
else
curr_fp += 1; /// The curve moves one step right.
}
area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0;
return area;
}
static void vector(
const IColumn & scores,
const IColumn & labels,
const ColumnArray::Offsets & offsets,
PaddedPODArray<Float64> & result,
size_t input_rows_count)
{
result.resize(input_rows_count);
ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < input_rows_count; ++i)
{
auto next_offset = offsets[i];
result[i] = apply(scores, labels, current_offset, next_offset);
current_offset = next_offset;
}
}
public:
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
for (size_t i = 0; i < getNumberOfArguments(); ++i)
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "All arguments for function {} must be an array.", getName());
const auto & nested_type = array_type->getNestedType();
if (!isNativeNumber(nested_type) && !isEnum(nested_type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}",
getName(), nested_type->getName());
}
return std::make_shared<DataTypeFloat64>();
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst();
ColumnPtr col2 = arguments[1].column->convertToFullColumnIfConst();
const ColumnArray * col_array1 = checkAndGetColumn<ColumnArray>(col1.get());
if (!col_array1)
throw Exception(ErrorCodes::ILLEGAL_COLUMN,
"Illegal column {} of first argument of function {}", arguments[0].column->getName(), getName());
const ColumnArray * col_array2 = checkAndGetColumn<ColumnArray>(col2.get());
if (!col_array2)
throw Exception(ErrorCodes::ILLEGAL_COLUMN,
"Illegal column {} of second argument of function {}", arguments[1].column->getName(), getName());
if (!col_array1->hasEqualOffsets(*col_array2))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Array arguments for function {} must have equal sizes", getName());
auto col_res = ColumnVector<Float64>::create();
vector(
col_array1->getData(),
col_array2->getData(),
col_array1->getOffsets(),
col_res->getData(),
input_rows_count);
return col_res;
}
};
REGISTER_FUNCTION(ArrayAUCUnscaled)
{
factory.registerFunction<FunctionArrayAUCUnscaled>();
}
}


@ -1216,6 +1216,7 @@
"argMinState"
"array"
"arrayAUC"
"arrayAUCUnscaled"
"arrayAll"
"arrayAvg"
"arrayCompact"


@ -529,6 +529,7 @@
"argMinState"
"array"
"arrayAUC"
"arrayAUCUnscaled"
"arrayAll"
"arrayAvg"
"arrayCompact"


@ -19,6 +19,7 @@
"Array"
"arrayAll"
"arrayAUC"
"arrayAUCUnscaled"
"arrayCompact"
"arrayConcat"
"arrayCount"


@ -0,0 +1,25 @@
3
3
3
3
3
3
3
3
3
1
1
1
1
1
1
1
0
0
0
0.5
1
0
1.5
2
1.5


@ -0,0 +1,30 @@
select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)));
select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)));
select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))));
select arrayAUCUnscaled([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))));
select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]);
select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]);
select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]);
select arrayAUCUnscaled(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]);
select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]);
select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]);
select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]);
select arrayAUCUnscaled(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]);
select arrayAUCUnscaled(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]);
select arrayAUCUnscaled([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]);
select arrayAUCUnscaled([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]);
SELECT arrayAUCUnscaled([], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUCUnscaled([1], [1]);
SELECT arrayAUCUnscaled([1], []); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUCUnscaled([], [1]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUCUnscaled([1, 2], [3]); -- { serverError BAD_ARGUMENTS }
SELECT arrayAUCUnscaled([1], [2, 3]); -- { serverError BAD_ARGUMENTS }
SELECT arrayAUCUnscaled([1, 1], [1, 1]);
SELECT arrayAUCUnscaled([1, 1], [0, 0]);
SELECT arrayAUCUnscaled([1, 1], [0, 1]);
SELECT arrayAUCUnscaled([0, 1], [0, 1]);
SELECT arrayAUCUnscaled([1, 0], [0, 1]);
SELECT arrayAUCUnscaled([0, 0, 1], [0, 1, 1]);
SELECT arrayAUCUnscaled([0, 1, 1], [0, 1, 1]);
SELECT arrayAUCUnscaled([0, 1, 1], [0, 0, 1]);


@ -1153,6 +1153,7 @@ argMin
argmax
argmin
arrayAUC
arrayAUCUnscaled
arrayAll
arrayAvg
arrayCompact