add array function auc

This commit is contained in:
liyang 2020-01-17 21:52:02 +08:00
parent e2d8360a76
commit b914c4262d
6 changed files with 264 additions and 1 deletions

View File

@ -0,0 +1,157 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnNullable.h>
#include <Common/FieldVisitors.h>
#include <Common/memcmpSmall.h>
#include <Common/assert_cast.h>
#include <Interpreters/castColumn.h>
#include <Interpreters/Context.h>
namespace DB
{
/// Error-code constants passed to Exception when reporting failures.
/// They are only declared here (extern); their definitions are not part of this file.
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
template <typename Method, typename Name>
class FunctionArrayScalarProduct : public IFunction
{
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionArrayScalarProduct>(context); }
FunctionArrayScalarProduct(const Context & context_): context(context_) {}
private:
using ResultColumnType = ColumnVector<typename Method::ResultType>;
template <typename T>
bool executeNumber(Block & block, const ColumnNumbers & arguments, size_t result)
{
return executeNumberNumber<T, UInt8>(block, arguments, result)
|| executeNumberNumber<T, UInt16>(block, arguments, result)
|| executeNumberNumber<T, UInt32>(block, arguments, result)
|| executeNumberNumber<T, UInt64>(block, arguments, result)
|| executeNumberNumber<T, Int8>(block, arguments, result)
|| executeNumberNumber<T, Int16>(block, arguments, result)
|| executeNumberNumber<T, Int32>(block, arguments, result)
|| executeNumberNumber<T, Int64>(block, arguments, result)
|| executeNumberNumber<T, Float32>(block, arguments, result)
|| executeNumberNumber<T, Float64>(block, arguments, result);
}
template <typename T, typename U>
bool executeNumberNumber(Block & block, const ColumnNumbers & arguments, size_t result)
{
ColumnPtr col1 = block.getByPosition(arguments[0]).column->convertToFullColumnIfConst();
ColumnPtr col2 = block.getByPosition(arguments[1]).column->convertToFullColumnIfConst();
if (! col1 || ! col2)
return false;
const ColumnArray* col_array1 = checkAndGetColumn<ColumnArray>(col1.get());
const ColumnArray* col_array2 = checkAndGetColumn<ColumnArray>(col2.get());
if (! col_array1 || ! col_array2)
return false;
const ColumnVector<T> * col_nested1 = checkAndGetColumn<ColumnVector<T>>(col_array1->getData());
const ColumnVector<U> * col_nested2 = checkAndGetColumn<ColumnVector<U>>(col_array2->getData());
if (! col_nested1 || ! col_nested2)
return false;
auto col_res = ResultColumnType::create();
vector(col_nested1->getData(), col_array1->getOffsets(),
col_nested2->getData(), col_array2->getOffsets(), col_res->getData());
block.getByPosition(result).column = std::move(col_res);
return true;
}
template <typename T, typename U>
static void vector(const PaddedPODArray<T> & data1, const ColumnArray::Offsets & offsets1,
const PaddedPODArray<U> & data2, const ColumnArray::Offsets & offsets2,
PaddedPODArray<typename Method::ResultType> & result)
{
size_t size = offsets1.size();
result.resize(size);
ColumnArray::Offset current_offset1 = 0;
ColumnArray::Offset current_offset2 = 0;
for (size_t i = 0; i < size; ++i) {
size_t array1_size = offsets1[i] - current_offset1;
size_t array2_size = offsets2[i] - current_offset2;
result[i] = Method::apply(data1, current_offset1, array1_size, data2, current_offset2, array2_size);
current_offset1 = offsets1[i];
current_offset2 = offsets2[i];
}
}
public:
/// Get function name.
String getName() const override
{
return name;
}
size_t getNumberOfArguments() const override { return 2; }
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
// Basic type check
std::vector<DataTypePtr> nested_types(2, nullptr);
for (size_t i = 0; i < getNumberOfArguments(); ++i)
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].get());
if (!array_type)
throw Exception("All argument for function " + getName() + " must be an array.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
auto & nested_type = array_type->getNestedType();
WhichDataType which(nested_type);
bool is_number = which.isNativeInt() || which.isNativeUInt() || which.isFloat();
if (! is_number)
{
throw Exception(getName() + " cannot process values of type " + nested_type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
nested_types[i] = nested_type;
}
// Detail type check in Method, then return ReturnType
return Method::getReturnType(nested_types[0], nested_types[1]);
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /* input_rows_count */) override
{
if (!(executeNumber<UInt8>(block, arguments, result)
|| executeNumber<UInt16>(block, arguments, result)
|| executeNumber<UInt32>(block, arguments, result)
|| executeNumber<UInt64>(block, arguments, result)
|| executeNumber<Int8>(block, arguments, result)
|| executeNumber<Int16>(block, arguments, result)
|| executeNumber<Int32>(block, arguments, result)
|| executeNumber<Int64>(block, arguments, result)
|| executeNumber<Float32>(block, arguments, result)
|| executeNumber<Float64>(block, arguments, result)))
throw Exception{"Illegal column " + block.getByPosition(arguments[0]).column->getName()
+ " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN};
}
private:
const Context & context;
};
}

View File

@ -0,0 +1,97 @@
#include <vector>
#include <algorithm>
#include <Functions/FunctionFactory.h>
#include "arrayScalarProduct.h"
namespace DB
{
/// Name tag for the function: `name` is the identifier under which it is exposed ("auc").
struct NameAUC
{
    static constexpr auto name = "auc";
};
/** Per-row implementation of the `auc` function.
  *
  * Given an array of scores and a parallel array of labels, computes the area under the ROC curve
  * using the rank-sum formula:
  *
  *     AUC = (sum of ranks of positive samples - num_pos * (num_pos + 1) / 2) / (num_pos * num_neg)
  *
  * where ranks are 1-based positions after sorting by score ascending, and samples with equal
  * scores all receive the average rank of their group.
  */
class AUCImpl
{
public:
    using ResultType = Float64;

    /// A score converted to ResultType together with its label normalised to 0 / 1.
    struct ScoreLabel
    {
        ResultType score;
        UInt8 label;
    };

    /// Checks that the element type of the label array is UInt8 and returns the result type (Float64).
    /// Throws ILLEGAL_TYPE_OF_ARGUMENT otherwise. The element type of the score array is not inspected here.
    static DataTypePtr getReturnType(const DataTypePtr & /* nested_type1 */, const DataTypePtr & nested_type2)
    {
        WhichDataType which2(nested_type2);
        if (!which2.isUInt8())
        {
            throw Exception(std::string(NameAUC::name) + ": label type must be UInt8",
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        }
        return std::make_shared<DataTypeNumber<ResultType>>();
    }

    /** Computes AUC for one row.
      *
      * scores / labels are the flattened element arrays; the row occupies
      * [score_offset, score_offset + score_len) and [label_offset, label_offset + label_len) respectively.
      * Any non-zero label counts as positive.
      *
      * Returns a value-initialised ResultType (0) for an empty row.
      * Throws LOGICAL_ERROR if the two arrays of the row differ in length.
      * If the row contains only positive or only negative labels, the final division is 0 / 0
      * (NaN under IEEE 754 arithmetic).
      */
    template <typename T, typename U>
    static ResultType apply(const PaddedPODArray<T> & scores, ColumnArray::Offset score_offset, size_t score_len,
        const PaddedPODArray<U> & labels, ColumnArray::Offset label_offset, size_t label_len)
    {
        if (score_len != label_len)
            throw Exception{"Unmatched length of arrays in " + std::string(NameAUC::name),
                ErrorCodes::LOGICAL_ERROR};
        if (score_len == 0)
            return {};

        /// Collect (score, label) pairs and count positives / negatives.
        size_t num_pos = 0;
        size_t num_neg = 0;
        std::vector<ScoreLabel> pairs(score_len);
        for (size_t i = 0; i < score_len; ++i)
        {
            pairs[i].score = static_cast<ResultType>(scores[i + score_offset]);
            pairs[i].label = (labels[i + label_offset] ? 1 : 0);
            if (pairs[i].label)
                ++num_pos;
            else
                ++num_neg;
        }

        /// Order pairs by score ascending; the order inside a group of equal scores does not matter
        /// because every member of the group gets the group's average rank.
        std::sort(pairs.begin(), pairs.end(),
            [](const auto & lhs, const auto & rhs) { return lhs.score < rhs.score; });

        /// Accumulators for the current group of equal scores.
        size_t curr_cnt = 0;        /// number of samples in the group
        size_t curr_pos_cnt = 0;    /// number of positive samples in the group
        size_t curr_sum = 0;        /// sum of the 1-based ranks in the group

        /// The initial value is only a placeholder: if the smallest score happens to equal it,
        /// the first element goes through the "same group" branch, which on zero-initialised
        /// accumulators produces exactly the state the "new group" branch would.
        ResultType last_score = -1;
        ResultType rank_sum = 0;

        for (size_t i = 0; i < pairs.size(); ++i)
        {
            if (pairs[i].score == last_score)
            {
                curr_sum += i + 1;
                ++curr_cnt;
                if (pairs[i].label)
                    ++curr_pos_cnt;
            }
            else
            {
                /// Flush the finished group: each of its positives contributes the average rank.
                /// The multiplication is done in floating point so that it cannot wrap around
                /// for very large groups.
                if (i > 0)
                    rank_sum += ResultType(curr_sum) * curr_pos_cnt / curr_cnt;
                curr_sum = i + 1;
                curr_cnt = 1;
                curr_pos_cnt = pairs[i].label ? 1 : 0;
            }
            last_score = pairs[i].score;
        }

        /// Flush the last group.
        rank_sum += ResultType(curr_sum) * curr_pos_cnt / curr_cnt;

        /// Products of counts are also evaluated in floating point to avoid unsigned overflow.
        const ResultType min_rank_sum = ResultType(num_pos) * (num_pos + 1) / 2;
        const ResultType num_pairs = ResultType(num_pos) * num_neg;
        return (rank_sum - min_rank_sum) / num_pairs;
    }
};
/// auc(array_score, array_label) - calculates AUC from an array of scores and an array of labels.
/// The function is the generic two-array template instantiated with AUCImpl as the per-row
/// method and NameAUC as the name tag.
using FunctionAUC = FunctionArrayScalarProduct<AUCImpl, NameAUC>;
/// Registers FunctionAUC in the given function factory.
void registerFunctionAUC(FunctionFactory & factory)
{
factory.registerFunction<FunctionAUC>();
}
}

View File

@ -33,7 +33,7 @@ void registerFunctionArrayDistinct(FunctionFactory & factory);
void registerFunctionArrayFlatten(FunctionFactory & factory);
void registerFunctionArrayWithConstant(FunctionFactory & factory);
void registerFunctionArrayZip(FunctionFactory & factory);
void registerFunctionAUC(FunctionFactory &);
void registerFunctionsArray(FunctionFactory & factory)
{
@ -67,6 +67,7 @@ void registerFunctionsArray(FunctionFactory & factory)
registerFunctionArrayFlatten(factory);
registerFunctionArrayWithConstant(factory);
registerFunctionArrayZip(factory);
registerFunctionAUC(factory);
}
}

View File

@ -0,0 +1 @@
0.75

View File

@ -0,0 +1 @@
select auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])

View File

@ -896,5 +896,11 @@ Result:
│ [('a','d'),('b','e'),('c','f')] │
└────────────────────────────────────────────┘
```
## auc(arr_scores, arr_labels)
Returns AUC (Area Under the Curve), a concept in machine learning; for more details see https://developers.google.com/machine-learning/crash-course/classification/roc-and-auc.
`arr_scores` represents the scores given by the prediction model, while `arr_labels` represents the labels of the samples, usually 1 for a positive sample and 0 for a negative sample.
[Original article](https://clickhouse.yandex/docs/en/query_language/functions/array_functions/) <!--hide-->