mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 07:31:57 +00:00
Merge branch 'pr_add_auc' of https://github.com/taiyang-li/ClickHouse into taiyang-li-pr_add_auc
This commit is contained in:
commit
256acebfca
142
dbms/src/Functions/array/arrayScalarProduct.h
Normal file
142
dbms/src/Functions/array/arrayScalarProduct.h
Normal file
@ -0,0 +1,142 @@
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
#include <Columns/ColumnNullable.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeEnum.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/castColumn.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/memcmpSmall.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
template <typename Method, typename Name>
|
||||
class FunctionArrayScalarProduct : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = Name::name;
|
||||
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionArrayScalarProduct>(context); }
|
||||
FunctionArrayScalarProduct(const Context & context_) : context(context_) {}
|
||||
|
||||
private:
|
||||
using ResultColumnType = ColumnVector<typename Method::ResultType>;
|
||||
|
||||
template <typename T>
|
||||
bool executeNumber(Block & block, const ColumnNumbers & arguments, size_t result)
|
||||
{
|
||||
return executeNumberNumber<T, UInt8>(block, arguments, result) || executeNumberNumber<T, UInt16>(block, arguments, result)
|
||||
|| executeNumberNumber<T, UInt32>(block, arguments, result) || executeNumberNumber<T, UInt64>(block, arguments, result)
|
||||
|| executeNumberNumber<T, Int8>(block, arguments, result) || executeNumberNumber<T, Int16>(block, arguments, result)
|
||||
|| executeNumberNumber<T, Int32>(block, arguments, result) || executeNumberNumber<T, Int64>(block, arguments, result)
|
||||
|| executeNumberNumber<T, Float32>(block, arguments, result) || executeNumberNumber<T, Float64>(block, arguments, result);
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename U>
|
||||
bool executeNumberNumber(Block & block, const ColumnNumbers & arguments, size_t result)
|
||||
{
|
||||
ColumnPtr col1 = block.getByPosition(arguments[0]).column->convertToFullColumnIfConst();
|
||||
ColumnPtr col2 = block.getByPosition(arguments[1]).column->convertToFullColumnIfConst();
|
||||
if (!col1 || !col2)
|
||||
return false;
|
||||
|
||||
const ColumnArray * col_array1 = checkAndGetColumn<ColumnArray>(col1.get());
|
||||
const ColumnArray * col_array2 = checkAndGetColumn<ColumnArray>(col2.get());
|
||||
if (!col_array1 || !col_array2)
|
||||
return false;
|
||||
|
||||
const ColumnVector<T> * col_nested1 = checkAndGetColumn<ColumnVector<T>>(col_array1->getData());
|
||||
const ColumnVector<U> * col_nested2 = checkAndGetColumn<ColumnVector<U>>(col_array2->getData());
|
||||
if (!col_nested1 || !col_nested2)
|
||||
return false;
|
||||
|
||||
auto col_res = ResultColumnType::create();
|
||||
vector(col_nested1->getData(), col_array1->getOffsets(), col_nested2->getData(), col_array2->getOffsets(), col_res->getData());
|
||||
block.getByPosition(result).column = std::move(col_res);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
static void vector(
|
||||
const PaddedPODArray<T> & data1,
|
||||
const ColumnArray::Offsets & offsets1,
|
||||
const PaddedPODArray<U> & data2,
|
||||
const ColumnArray::Offsets & offsets2,
|
||||
PaddedPODArray<typename Method::ResultType> & result)
|
||||
{
|
||||
size_t size = offsets1.size();
|
||||
result.resize(size);
|
||||
|
||||
ColumnArray::Offset current_offset1 = 0;
|
||||
ColumnArray::Offset current_offset2 = 0;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
size_t array1_size = offsets1[i] - current_offset1;
|
||||
size_t array2_size = offsets2[i] - current_offset2;
|
||||
result[i] = Method::apply(data1, current_offset1, array1_size, data2, current_offset2, array2_size);
|
||||
|
||||
current_offset1 = offsets1[i];
|
||||
current_offset2 = offsets2[i];
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
/// Get function name.
|
||||
String getName() const override { return name; }
|
||||
|
||||
size_t getNumberOfArguments() const override { return 2; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
// Basic type check
|
||||
std::vector<DataTypePtr> nested_types(2, nullptr);
|
||||
for (size_t i = 0; i < getNumberOfArguments(); ++i)
|
||||
{
|
||||
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
|
||||
if (!array_type)
|
||||
throw Exception("All arguments for function " + getName() + " must be an array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
auto & nested_type = array_type->getNestedType();
|
||||
if (!isNativeNumber(nested_type) && !isEnum(nested_type))
|
||||
throw Exception(
|
||||
getName() + " cannot process values of type " + nested_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
nested_types[i] = nested_type;
|
||||
}
|
||||
|
||||
// Detail type check in Method, then return ReturnType
|
||||
return Method::getReturnType(nested_types[0], nested_types[1]);
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /* input_rows_count */) override
|
||||
{
|
||||
if (!(executeNumber<UInt8>(block, arguments, result) || executeNumber<UInt16>(block, arguments, result)
|
||||
|| executeNumber<UInt32>(block, arguments, result) || executeNumber<UInt64>(block, arguments, result)
|
||||
|| executeNumber<Int8>(block, arguments, result) || executeNumber<Int16>(block, arguments, result)
|
||||
|| executeNumber<Int32>(block, arguments, result) || executeNumber<Int64>(block, arguments, result)
|
||||
|| executeNumber<Float32>(block, arguments, result) || executeNumber<Float64>(block, arguments, result)))
|
||||
throw Exception{"Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of first argument of function "
|
||||
+ getName(),
|
||||
ErrorCodes::ILLEGAL_COLUMN};
|
||||
}
|
||||
|
||||
private:
|
||||
const Context & context;
|
||||
};
|
||||
|
||||
}
|
133
dbms/src/Functions/array/array_auc.cpp
Normal file
133
dbms/src/Functions/array/array_auc.cpp
Normal file
@ -0,0 +1,133 @@
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include "arrayScalarProduct.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
struct NameArrayAUC
|
||||
{
|
||||
static constexpr auto name = "arrayAUC";
|
||||
};
|
||||
|
||||
class ArrayAUCImpl
|
||||
{
|
||||
public:
|
||||
using ResultType = Float64;
|
||||
using LabelValueSet = std::set<Int16>;
|
||||
using LabelValueSets = std::vector<LabelValueSet>;
|
||||
|
||||
inline static const LabelValueSets expect_label_value_sets = {{0, 1}, {-1, 1}};
|
||||
|
||||
struct ScoreLabel
|
||||
{
|
||||
ResultType score;
|
||||
bool label;
|
||||
};
|
||||
|
||||
static DataTypePtr getReturnType(const DataTypePtr & /* score_type */, const DataTypePtr & label_type)
|
||||
{
|
||||
WhichDataType which(label_type);
|
||||
// Labels values are either {0, 1} or {-1, 1}, and its type must be one of (Enum8, UInt8, Int8)
|
||||
if (!which.isUInt8() && !which.isEnum8() && !which.isInt8())
|
||||
{
|
||||
throw Exception(
|
||||
std::string(NameArrayAUC::name) + "lable type must be UInt8, Enum8 or Int8", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
if (which.isEnum8())
|
||||
{
|
||||
auto type8 = checkAndGetDataType<DataTypeEnum8>(label_type.get());
|
||||
if (!type8)
|
||||
throw Exception(std::string(NameArrayAUC::name) + "lable type not valid Enum8", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
LabelValueSet value_set;
|
||||
const auto & values = type8->getValues();
|
||||
for (const auto & value : values)
|
||||
value_set.insert(value.second);
|
||||
|
||||
if (std::find(expect_label_value_sets.begin(), expect_label_value_sets.end(), value_set) == expect_label_value_sets.end())
|
||||
throw Exception(
|
||||
std::string(NameArrayAUC::name) + "lable values must be {0, 1} or {-1, 1}", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
return std::make_shared<DataTypeNumber<ResultType>>();
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
static ResultType apply(
|
||||
const PaddedPODArray<T> & scores,
|
||||
ColumnArray::Offset score_offset,
|
||||
size_t score_len,
|
||||
const PaddedPODArray<U> & labels,
|
||||
ColumnArray::Offset label_offset,
|
||||
size_t label_len)
|
||||
{
|
||||
if (score_len != label_len)
|
||||
throw Exception{"Unmatched length of arrays in " + std::string(NameArrayAUC::name), ErrorCodes::LOGICAL_ERROR};
|
||||
if (score_len == 0)
|
||||
return {};
|
||||
|
||||
// Calculate positive and negative label number and restore scores and labels in vector
|
||||
size_t num_pos = 0;
|
||||
size_t num_neg = 0;
|
||||
LabelValueSet label_value_set;
|
||||
std::vector<ScoreLabel> pairs(score_len);
|
||||
for (size_t i = 0; i < score_len; ++i)
|
||||
{
|
||||
pairs[i].score = scores[i + score_offset];
|
||||
pairs[i].label = (labels[i + label_offset] == 1);
|
||||
if (pairs[i].label)
|
||||
++num_pos;
|
||||
else
|
||||
++num_neg;
|
||||
|
||||
label_value_set.insert(labels[i + label_offset]);
|
||||
}
|
||||
|
||||
// Label values must be {0, 1} or {-1, 1}
|
||||
if (std::find(expect_label_value_sets.begin(), expect_label_value_sets.end(), label_value_set) == expect_label_value_sets.end())
|
||||
throw Exception(
|
||||
std::string(NameArrayAUC::name) + "lable values must be {0, 1} or {-1, 1}", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
// Order pairs of score and lable by score ascending
|
||||
std::sort(pairs.begin(), pairs.end(), [](const auto & lhs, const auto & rhs) { return lhs.score < rhs.score; });
|
||||
|
||||
// Calculate AUC
|
||||
size_t curr_cnt = 0;
|
||||
size_t curr_pos_cnt = 0;
|
||||
Int64 curr_sum = 0;
|
||||
ResultType last_score = -1;
|
||||
ResultType rank_sum = 0;
|
||||
for (size_t i = 0; i < pairs.size(); ++i)
|
||||
{
|
||||
if (pairs[i].score == last_score)
|
||||
{
|
||||
curr_sum += i + 1;
|
||||
++curr_cnt;
|
||||
if (pairs[i].label)
|
||||
++curr_pos_cnt;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (i > 0)
|
||||
rank_sum += ResultType(curr_sum * curr_pos_cnt) / curr_cnt;
|
||||
curr_sum = i + 1;
|
||||
curr_cnt = 1;
|
||||
curr_pos_cnt = pairs[i].label ? 1 : 0;
|
||||
}
|
||||
last_score = pairs[i].score;
|
||||
}
|
||||
rank_sum += ResultType(curr_sum * curr_pos_cnt) / curr_cnt;
|
||||
return (rank_sum - num_pos * (num_pos + 1) / 2) / (num_pos * num_neg);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// auc(array_score, array_label) - Calculate AUC with array of score and label
|
||||
using FunctionArrayAUC = FunctionArrayScalarProduct<ArrayAUCImpl, NameArrayAUC>;
|
||||
|
||||
void registerFunctionArrayAUC(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionArrayAUC>();
|
||||
}
|
||||
}
|
@ -1,6 +1,5 @@
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class FunctionFactory;
|
||||
|
||||
void registerFunctionArray(FunctionFactory & factory);
|
||||
@ -33,7 +32,7 @@ void registerFunctionArrayDistinct(FunctionFactory & factory);
|
||||
void registerFunctionArrayFlatten(FunctionFactory & factory);
|
||||
void registerFunctionArrayWithConstant(FunctionFactory & factory);
|
||||
void registerFunctionArrayZip(FunctionFactory & factory);
|
||||
|
||||
void registerFunctionArrayAUC(FunctionFactory &);
|
||||
|
||||
void registerFunctionsArray(FunctionFactory & factory)
|
||||
{
|
||||
@ -67,7 +66,7 @@ void registerFunctionsArray(FunctionFactory & factory)
|
||||
registerFunctionArrayFlatten(factory);
|
||||
registerFunctionArrayWithConstant(factory);
|
||||
registerFunctionArrayZip(factory);
|
||||
registerFunctionArrayAUC(factory);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
16
dbms/tests/queries/0_stateless/01064_array_auc.reference
Normal file
16
dbms/tests/queries/0_stateless/01064_array_auc.reference
Normal file
@ -0,0 +1,16 @@
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.25
|
||||
0.25
|
||||
0.25
|
||||
0.25
|
||||
0.25
|
||||
0.125
|
||||
0.25
|
16
dbms/tests/queries/0_stateless/01064_array_auc.sql
Normal file
16
dbms/tests/queries/0_stateless/01064_array_auc.sql
Normal file
@ -0,0 +1,16 @@
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)));
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)));
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))));
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))));
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]);
|
||||
select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]);
|
||||
select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]);
|
@ -938,5 +938,33 @@ Result:
|
||||
│ [('a','d'),('b','e'),('c','f')] │
|
||||
└────────────────────────────────────────────┘
|
||||
```
|
||||
## arrayAUC {#arrayauc}
|
||||
Calculate AUC(Area Under the Curve, which is a concept in machine learning, see more details: https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve).
|
||||
|
||||
**Syntax**
|
||||
```sql
|
||||
arrayAUC(arr_scores, arr_labels)
|
||||
```
|
||||
|
||||
**Parameters**
|
||||
- `arr_scores` — scores prediction model gives.
|
||||
- `arr_labels` — labels of samples, usually 1 for positive sample and 0 for negtive sample.
|
||||
|
||||
**Returned value**
|
||||
return AUC value with type Float64.
|
||||
|
||||
**Example**
|
||||
Query:
|
||||
```sql
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])
|
||||
```
|
||||
|
||||
Result:
|
||||
|
||||
```text
|
||||
┌─arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])─┐
|
||||
│ 0.75 │
|
||||
└────────────────────────────────────────-----──┘
|
||||
```
|
||||
|
||||
[Original article](https://clickhouse.tech/docs/en/query_language/functions/array_functions/) <!--hide-->
|
||||
|
Loading…
Reference in New Issue
Block a user