mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
fix bugs
This commit is contained in:
parent
71d46bb521
commit
f0f0768e81
@ -4,6 +4,7 @@
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeEnum.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
@ -112,12 +113,9 @@ public:
|
||||
throw Exception("All arguments for function " + getName() + " must be an array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
auto & nested_type = array_type->getNestedType();
|
||||
bool is_number = isNativeNumber(nested_type)
|
||||
if (!is_number)
|
||||
{
|
||||
if (!isNativeNumber(nested_type) && !isEnum(nested_type))
|
||||
throw Exception(
|
||||
getName() + " cannot process values of type " + nested_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
nested_types[i] = nested_type;
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,10 @@ class ArrayAUCImpl
|
||||
{
|
||||
public:
|
||||
using ResultType = Float64;
|
||||
using LabelValueSet = std::set<Int16>;
|
||||
using LabelValueSets = std::vector<LabelValueSet>;
|
||||
|
||||
inline static const LabelValueSets expect_label_value_sets = {{0, 1}, {-1, 1}};
|
||||
|
||||
struct ScoreLabel
|
||||
{
|
||||
@ -27,23 +31,22 @@ public:
|
||||
// Labels values are either {0, 1} or {-1, 1}, and its type must be one of (Enum8, UInt8, Int8)
|
||||
if (!which.isUInt8() && !which.isEnum8() && !which.isInt8())
|
||||
{
|
||||
throw Exception(std::string(NameArrayAUC::name) + "lable type must be UInt8", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
throw Exception(
|
||||
std::string(NameArrayAUC::name) + "lable type must be UInt8, Enum8 or Int8", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
if (which.isEnum8())
|
||||
{
|
||||
auto type8 = checkAndGetDataType<DataTypeEnum8>(label_type.get());
|
||||
if (type8)
|
||||
if (!type8)
|
||||
throw Exception(std::string(NameArrayAUC::name) + "lable type not valid Enum8", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
std::set<Int8> valSet;
|
||||
LabelValueSet value_set;
|
||||
const auto & values = type8->getValues();
|
||||
for (const auto & value : values)
|
||||
{
|
||||
valSet.insert(value.second);
|
||||
}
|
||||
value_set.insert(value.second);
|
||||
|
||||
if (valSet != {0, 1} || valSet != {-1, 1})
|
||||
if (std::find(expect_label_value_sets.begin(), expect_label_value_sets.end(), value_set) == expect_label_value_sets.end())
|
||||
throw Exception(
|
||||
std::string(NameArrayAUC::name) + "lable values must be {0, 1} or {-1, 1}", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
@ -67,7 +70,7 @@ public:
|
||||
// Calculate positive and negative label number and restore scores and labels in vector
|
||||
size_t num_pos = 0;
|
||||
size_t num_neg = 0;
|
||||
std::set<Int16> labelValSet;
|
||||
LabelValueSet label_value_set;
|
||||
std::vector<ScoreLabel> pairs(score_len);
|
||||
for (size_t i = 0; i < score_len; ++i)
|
||||
{
|
||||
@ -78,11 +81,11 @@ public:
|
||||
else
|
||||
++num_neg;
|
||||
|
||||
labelValSet.insert(labels[i + label_offset]);
|
||||
label_value_set.insert(labels[i + label_offset]);
|
||||
}
|
||||
|
||||
// Label values must be {0, 1} or {-1, 1}
|
||||
if (labelValSet != {0, 1} && labelValSet != {-1, 1})
|
||||
if (std::find(expect_label_value_sets.begin(), expect_label_value_sets.end(), label_value_set) == expect_label_value_sets.end())
|
||||
throw Exception(
|
||||
std::string(NameArrayAUC::name) + "lable values must be {0, 1} or {-1, 1}", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
@ -119,6 +122,7 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// auc(array_score, array_label) - Calculate AUC with array of score and label
|
||||
using FunctionArrayAUC = FunctionArrayScalarProduct<ArrayAUCImpl, NameArrayAUC>;
|
||||
|
||||
@ -126,6 +130,4 @@ void registerFunctionArrayAUC(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionArrayAUC>();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
@ -1 +1,16 @@
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.75
|
||||
0.25
|
||||
0.25
|
||||
0.25
|
||||
0.25
|
||||
0.25
|
||||
0.125
|
||||
0.25
|
||||
|
@ -1 +1,16 @@
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1])
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]);
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([0, 0, 1, 1] as Array(Int8)));
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast([-1, -1, 1, 1] as Array(Int8)));
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = 0, 'true' = 1))));
|
||||
select arrayAUC([0.1, 0.4, 0.35, 0.8], cast(['false', 'false', 'true', 'true'] as Array(Enum8('false' = -1, 'true' = 1))));
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt8)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt16)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt32)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([10, 40, 35, 80] as Array(UInt64)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int8)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int16)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int32)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-10, -40, -35, -80] as Array(Int64)), [0, 0, 1, 1]);
|
||||
select arrayAUC(cast([-0.1, -0.4, -0.35, -0.8] as Array(Float32)) , [0, 0, 1, 1]);
|
||||
select arrayAUC([0, 3, 5, 6, 7.5, 8], [1, 0, 1, 0, 0, 0]);
|
||||
select arrayAUC([0.1, 0.35, 0.4, 0.8], [1, 0, 1, 0]);
|
Loading…
Reference in New Issue
Block a user