Merge pull request #69717 from gabrielmcg44/add-array-unscaled

Allow `arrayAUC` without scaling
This commit is contained in:
Robert Schulze 2024-09-19 08:37:18 +00:00 committed by GitHub
commit 396abf7636
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 173 additions and 21 deletions

View File

@ -2088,13 +2088,14 @@ Calculate AUC (Area Under the Curve, which is a concept in machine learning, see
**Syntax**
``` sql
arrayAUC(arr_scores, arr_labels)
arrayAUC(arr_scores, arr_labels[, scale])
```
**Arguments**
- `arr_scores` — scores prediction model gives.
- `arr_labels` — labels of samples, usually 1 for positive sample and 0 for negative sample.
- `scale` - Optional. Wether to return the normalized area. Default value: true. [Bool]
**Returned value**

View File

@ -14,6 +14,7 @@ namespace ErrorCodes
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int BAD_ARGUMENTS;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
@ -85,7 +86,8 @@ private:
const IColumn & scores,
const IColumn & labels,
ColumnArray::Offset current_offset,
ColumnArray::Offset next_offset)
ColumnArray::Offset next_offset,
bool scale)
{
struct ScoreLabel
{
@ -114,10 +116,10 @@ private:
size_t curr_fp = 0, curr_tp = 0;
for (size_t i = 0; i < size; ++i)
{
// Only increment the area when the score changes
/// Only increment the area when the score changes
if (sorted_labels[i].score != prev_score)
{
area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; // Trapezoidal area under curve (might degenerate to zero or to a rectangle)
area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; /// Trapezoidal area under curve (might degenerate to zero or to a rectangle)
prev_fp = curr_fp;
prev_tp = curr_tp;
prev_score = sorted_labels[i].score;
@ -131,12 +133,15 @@ private:
area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0;
/// Then normalize it dividing by the area to the area of rectangle.
/// Then normalize it, if scale is true, dividing by the area to the area of rectangle.
if (curr_tp == 0 || curr_tp == size)
return std::numeric_limits<Float64>::quiet_NaN();
return area / curr_tp / (size - curr_tp);
if (scale)
{
if (curr_tp == 0 || curr_tp == size)
return std::numeric_limits<Float64>::quiet_NaN();
return area / curr_tp / (size - curr_tp);
}
return area;
}
static void vector(
@ -144,7 +149,8 @@ private:
const IColumn & labels,
const ColumnArray::Offsets & offsets,
PaddedPODArray<Float64> & result,
size_t input_rows_count)
size_t input_rows_count,
bool scale)
{
result.resize(input_rows_count);
@ -152,28 +158,43 @@ private:
for (size_t i = 0; i < input_rows_count; ++i)
{
auto next_offset = offsets[i];
result[i] = apply(scores, labels, current_offset, next_offset);
result[i] = apply(scores, labels, current_offset, next_offset, scale);
current_offset = next_offset;
}
}
public:
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return false; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
for (size_t i = 0; i < getNumberOfArguments(); ++i)
size_t number_of_arguments = arguments.size();
if (number_of_arguments < 2 || number_of_arguments > 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
getName(), number_of_arguments);
for (size_t i = 0; i < 2; ++i)
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].type.get());
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "All arguments for function {} must be an array.", getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The two first arguments for function {} must be of type Array.", getName());
const auto & nested_type = array_type->getNestedType();
if (!isNativeNumber(nested_type) && !isEnum(nested_type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}",
getName(), nested_type->getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}", getName(), nested_type->getName());
}
if (number_of_arguments == 3)
{
if (!isBool(arguments[2].type) || arguments[2].column.get() == nullptr || !isColumnConst(*arguments[2].column))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument (scale) for function {} must be of type const Bool.", getName());
}
return std::make_shared<DataTypeFloat64>();
@ -181,6 +202,8 @@ public:
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
size_t number_of_arguments = arguments.size();
ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst();
ColumnPtr col2 = arguments[1].column->convertToFullColumnIfConst();
@ -197,6 +220,11 @@ public:
if (!col_array1->hasEqualOffsets(*col_array2))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Array arguments for function {} must have equal sizes", getName());
/// Handle third argument for scale (if passed, otherwise default to true)
bool scale = true;
if (number_of_arguments == 3 && input_rows_count > 0)
scale = arguments[2].column->getBool(0);
auto col_res = ColumnVector<Float64>::create();
vector(
@ -204,7 +232,8 @@ public:
col_array2->getData(),
col_array1->getOffsets(),
col_res->getData(),
input_rows_count);
input_rows_count,
scale);
return col_res;
}

View File

@ -14,3 +14,35 @@
0.25
0.125
0.25
0.75
0.75
0.75
0.75
0.75
0.75
0.75
0.75
0.75
0.25
0.25
0.25
0.25
0.25
0.125
0.25
3
3
3
3
3
3
3
3
3
1
1
1
1
1
1
1

View File

@ -13,4 +13,44 @@ select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]);
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]);
select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]);
select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]);
select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]);
select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]);
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true);
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), true);
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), true);
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), true);
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), true);
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], true);
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], true);
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], true);
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], true);
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], true);
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], true);
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], true);
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], true);
select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], true);
select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], true);
select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], true);
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false);
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), false);
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), false);
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), false);
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), false);
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], false);
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], false);
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], false);
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], false);
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], false);
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], false);
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], false);
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], false);
select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], false);
select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], false);
select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], false);
-- negative tests
select arrayAUC([0, 0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
select arrayAUC([0.1, 0.35], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS }
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], materialize(true)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true, true); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }

View File

@ -7,3 +7,21 @@ nan
0.75
1
0.75
nan
nan
nan
0.5
1
0
0.75
1
0.75
0
0
0
0.5
1
0
1.5
2
1.5

View File

@ -12,3 +12,35 @@ SELECT arrayAUC([1, 0], [0, 1]);
SELECT arrayAUC([0, 0, 1], [0, 1, 1]);
SELECT arrayAUC([0, 1, 1], [0, 1, 1]);
SELECT arrayAUC([0, 1, 1], [0, 0, 1]);
SELECT arrayAUC([], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUC([1], [1], true);
SELECT arrayAUC([1], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUC([], [1], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUC([1, 2], [3], true); -- { serverError BAD_ARGUMENTS }
SELECT arrayAUC([1], [2, 3], true); -- { serverError BAD_ARGUMENTS }
SELECT arrayAUC([1, 1], [1, 1], true);
SELECT arrayAUC([1, 1], [0, 0], true);
SELECT arrayAUC([1, 1], [0, 1], true);
SELECT arrayAUC([0, 1], [0, 1], true);
SELECT arrayAUC([1, 0], [0, 1], true);
SELECT arrayAUC([0, 0, 1], [0, 1, 1], true);
SELECT arrayAUC([0, 1, 1], [0, 1, 1], true);
SELECT arrayAUC([0, 1, 1], [0, 0, 1], true);
SELECT arrayAUC([], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUC([1], [1], false);
SELECT arrayAUC([1], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUC([], [1], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUC([1, 2], [3], false); -- { serverError BAD_ARGUMENTS }
SELECT arrayAUC([1], [2, 3], false); -- { serverError BAD_ARGUMENTS }
SELECT arrayAUC([1, 1], [1, 1], false);
SELECT arrayAUC([1, 1], [0, 0], false);
SELECT arrayAUC([1, 1], [0, 1], false);
SELECT arrayAUC([0, 1], [0, 1], false);
SELECT arrayAUC([1, 0], [0, 1], false);
SELECT arrayAUC([0, 0, 1], [0, 1, 1], false);
SELECT arrayAUC([0, 1, 1], [0, 1, 1], false);
SELECT arrayAUC([0, 1, 1], [0, 0, 1], false);
SELECT arrayAUC([0, 1, 1], [0, 0, 1], false, true); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayAUC([0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayAUC([0, 1, 1], [0, 0, 1], 'false'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayAUC([0, 1, 1], [0, 0, 1], 4); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }