mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-25 17:12:03 +00:00
Merge pull request #69717 from gabrielmcg44/add-array-unscaled
Allow `arrayAUC` without scaling
This commit is contained in:
commit
396abf7636
@ -2088,13 +2088,14 @@ Calculate AUC (Area Under the Curve, which is a concept in machine learning, see
|
||||
**Syntax**
|
||||
|
||||
``` sql
|
||||
arrayAUC(arr_scores, arr_labels)
|
||||
arrayAUC(arr_scores, arr_labels[, scale])
|
||||
```
|
||||
|
||||
**Arguments**
|
||||
|
||||
- `arr_scores` — scores prediction model gives.
|
||||
- `arr_labels` — labels of samples, usually 1 for positive sample and 0 for negative sample.
|
||||
- `scale` - Optional. Wether to return the normalized area. Default value: true. [Bool]
|
||||
|
||||
**Returned value**
|
||||
|
||||
|
@ -14,6 +14,7 @@ namespace ErrorCodes
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
|
||||
@ -85,7 +86,8 @@ private:
|
||||
const IColumn & scores,
|
||||
const IColumn & labels,
|
||||
ColumnArray::Offset current_offset,
|
||||
ColumnArray::Offset next_offset)
|
||||
ColumnArray::Offset next_offset,
|
||||
bool scale)
|
||||
{
|
||||
struct ScoreLabel
|
||||
{
|
||||
@ -114,10 +116,10 @@ private:
|
||||
size_t curr_fp = 0, curr_tp = 0;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
// Only increment the area when the score changes
|
||||
/// Only increment the area when the score changes
|
||||
if (sorted_labels[i].score != prev_score)
|
||||
{
|
||||
area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; // Trapezoidal area under curve (might degenerate to zero or to a rectangle)
|
||||
area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0; /// Trapezoidal area under curve (might degenerate to zero or to a rectangle)
|
||||
prev_fp = curr_fp;
|
||||
prev_tp = curr_tp;
|
||||
prev_score = sorted_labels[i].score;
|
||||
@ -131,20 +133,24 @@ private:
|
||||
|
||||
area += (curr_fp - prev_fp) * (curr_tp + prev_tp) / 2.0;
|
||||
|
||||
/// Then normalize it dividing by the area to the area of rectangle.
|
||||
/// Then normalize it, if scale is true, dividing by the area to the area of rectangle.
|
||||
|
||||
if (scale)
|
||||
{
|
||||
if (curr_tp == 0 || curr_tp == size)
|
||||
return std::numeric_limits<Float64>::quiet_NaN();
|
||||
|
||||
return area / curr_tp / (size - curr_tp);
|
||||
}
|
||||
return area;
|
||||
}
|
||||
|
||||
static void vector(
|
||||
const IColumn & scores,
|
||||
const IColumn & labels,
|
||||
const ColumnArray::Offsets & offsets,
|
||||
PaddedPODArray<Float64> & result,
|
||||
size_t input_rows_count)
|
||||
size_t input_rows_count,
|
||||
bool scale)
|
||||
{
|
||||
result.resize(input_rows_count);
|
||||
|
||||
@ -152,28 +158,43 @@ private:
|
||||
for (size_t i = 0; i < input_rows_count; ++i)
|
||||
{
|
||||
auto next_offset = offsets[i];
|
||||
result[i] = apply(scores, labels, current_offset, next_offset);
|
||||
result[i] = apply(scores, labels, current_offset, next_offset, scale);
|
||||
current_offset = next_offset;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
String getName() const override { return name; }
|
||||
size_t getNumberOfArguments() const override { return 2; }
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return false; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
bool isVariadic() const override { return true; }
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo &) const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
for (size_t i = 0; i < getNumberOfArguments(); ++i)
|
||||
size_t number_of_arguments = arguments.size();
|
||||
|
||||
if (number_of_arguments < 2 || number_of_arguments > 3)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
|
||||
getName(), number_of_arguments);
|
||||
|
||||
for (size_t i = 0; i < 2; ++i)
|
||||
{
|
||||
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
|
||||
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].type.get());
|
||||
if (!array_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "All arguments for function {} must be an array.", getName());
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The two first arguments for function {} must be of type Array.", getName());
|
||||
|
||||
const auto & nested_type = array_type->getNestedType();
|
||||
if (!isNativeNumber(nested_type) && !isEnum(nested_type))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}",
|
||||
getName(), nested_type->getName());
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "{} cannot process values of type {}", getName(), nested_type->getName());
|
||||
}
|
||||
|
||||
if (number_of_arguments == 3)
|
||||
{
|
||||
if (!isBool(arguments[2].type) || arguments[2].column.get() == nullptr || !isColumnConst(*arguments[2].column))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Third argument (scale) for function {} must be of type const Bool.", getName());
|
||||
}
|
||||
|
||||
return std::make_shared<DataTypeFloat64>();
|
||||
@ -181,6 +202,8 @@ public:
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
|
||||
{
|
||||
size_t number_of_arguments = arguments.size();
|
||||
|
||||
ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst();
|
||||
ColumnPtr col2 = arguments[1].column->convertToFullColumnIfConst();
|
||||
|
||||
@ -197,6 +220,11 @@ public:
|
||||
if (!col_array1->hasEqualOffsets(*col_array2))
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Array arguments for function {} must have equal sizes", getName());
|
||||
|
||||
/// Handle third argument for scale (if passed, otherwise default to true)
|
||||
bool scale = true;
|
||||
if (number_of_arguments == 3 && input_rows_count > 0)
|
||||
scale = arguments[2].column->getBool(0);
|
||||
|
||||
auto col_res = ColumnVector<Float64>::create();
|
||||
|
||||
vector(
|
||||
@ -204,7 +232,8 @@ public:
|
||||
col_array2->getData(),
|
||||
col_array1->getOffsets(),
|
||||
col_res->getData(),
|
||||
input_rows_count);
|
||||
input_rows_count,
|
||||
scale);
|
||||
|
||||
return col_res;
|
||||
}
|
||||
|
@ -14,3 +14,35 @@
|
||||
0.25
|
||||
0.125
|
||||
0.25
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.25
|
||||
0.25
|
||||
0.25
|
||||
0.25
|
||||
0.25
|
||||
0.125
|
||||
0.25
|
||||
3
|
||||
3
|
||||
3
|
||||
3
|
||||
3
|
||||
3
|
||||
3
|
||||
3
|
||||
3
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
|
@ -14,3 +14,43 @@ select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]);
|
||||
select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]);
|
||||
select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]);
|
||||
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), true);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), true);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), true);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), true);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], true);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], true);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], true);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], true);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], true);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], true);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], true);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], true);
|
||||
select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], true);
|
||||
select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], true);
|
||||
select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], true);
|
||||
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], false);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)), false);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)), false);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))), false);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))), false);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1], false);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1], false);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1], false);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1], false);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1], false);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1], false);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1], false);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1], false);
|
||||
select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1], false);
|
||||
select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0], false);
|
||||
select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0], false);
|
||||
|
||||
-- negative tests
|
||||
select arrayAUC([0, 0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
select arrayAUC([0.1, 0.35], [0, 0, 1, 1]); -- { serverError BAD_ARGUMENTS }
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], materialize(true)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1], true, true); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
|
@ -7,3 +7,21 @@ nan
|
||||
0.75
|
||||
1
|
||||
0.75
|
||||
nan
|
||||
nan
|
||||
nan
|
||||
0.5
|
||||
1
|
||||
0
|
||||
0.75
|
||||
1
|
||||
0.75
|
||||
0
|
||||
0
|
||||
0
|
||||
0.5
|
||||
1
|
||||
0
|
||||
1.5
|
||||
2
|
||||
1.5
|
||||
|
@ -12,3 +12,35 @@ SELECT arrayAUC([1, 0], [0, 1]);
|
||||
SELECT arrayAUC([0, 0, 1], [0, 1, 1]);
|
||||
SELECT arrayAUC([0, 1, 1], [0, 1, 1]);
|
||||
SELECT arrayAUC([0, 1, 1], [0, 0, 1]);
|
||||
SELECT arrayAUC([], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayAUC([1], [1], true);
|
||||
SELECT arrayAUC([1], [], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayAUC([], [1], true); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayAUC([1, 2], [3], true); -- { serverError BAD_ARGUMENTS }
|
||||
SELECT arrayAUC([1], [2, 3], true); -- { serverError BAD_ARGUMENTS }
|
||||
SELECT arrayAUC([1, 1], [1, 1], true);
|
||||
SELECT arrayAUC([1, 1], [0, 0], true);
|
||||
SELECT arrayAUC([1, 1], [0, 1], true);
|
||||
SELECT arrayAUC([0, 1], [0, 1], true);
|
||||
SELECT arrayAUC([1, 0], [0, 1], true);
|
||||
SELECT arrayAUC([0, 0, 1], [0, 1, 1], true);
|
||||
SELECT arrayAUC([0, 1, 1], [0, 1, 1], true);
|
||||
SELECT arrayAUC([0, 1, 1], [0, 0, 1], true);
|
||||
SELECT arrayAUC([], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayAUC([1], [1], false);
|
||||
SELECT arrayAUC([1], [], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayAUC([], [1], false); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayAUC([1, 2], [3], false); -- { serverError BAD_ARGUMENTS }
|
||||
SELECT arrayAUC([1], [2, 3], false); -- { serverError BAD_ARGUMENTS }
|
||||
SELECT arrayAUC([1, 1], [1, 1], false);
|
||||
SELECT arrayAUC([1, 1], [0, 0], false);
|
||||
SELECT arrayAUC([1, 1], [0, 1], false);
|
||||
SELECT arrayAUC([0, 1], [0, 1], false);
|
||||
SELECT arrayAUC([1, 0], [0, 1], false);
|
||||
SELECT arrayAUC([0, 0, 1], [0, 1, 1], false);
|
||||
SELECT arrayAUC([0, 1, 1], [0, 1, 1], false);
|
||||
SELECT arrayAUC([0, 1, 1], [0, 0, 1], false);
|
||||
SELECT arrayAUC([0, 1, 1], [0, 0, 1], false, true); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
SELECT arrayAUC([0, 1, 1]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
SELECT arrayAUC([0, 1, 1], [0, 0, 1], 'false'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayAUC([0, 1, 1], [0, 0, 1], 4); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
Loading…
Reference in New Issue
Block a user