mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Add aggregate function categoricalInformationValue (#8117)
* Add categorical iv aggregate function with tests
This commit is contained in:
parent
b8857e1a09
commit
c0028c3942
@ -0,0 +1,53 @@
|
||||
#include <AggregateFunctions/AggregateFunctionCategoricalInformationValue.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionCategoricalIV(
|
||||
const std::string & name,
|
||||
const DataTypes & arguments,
|
||||
const Array & params
|
||||
)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
|
||||
if (arguments.size() < 2)
|
||||
throw Exception(
|
||||
"Aggregate function " + name + " requires two or more arguments",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
for (auto & argument : arguments)
|
||||
{
|
||||
if (!WhichDataType(argument).isUInt8())
|
||||
throw Exception(
|
||||
"All the arguments of aggregate function " + name + " should be UInt8",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
return std::make_shared<AggregateFunctionCategoricalIV<>>(arguments, params);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionCategoricalIV(
|
||||
AggregateFunctionFactory & factory
|
||||
)
|
||||
{
|
||||
factory.registerFunction("categoricalInformationValue", createAggregateFunctionCategoricalIV);
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,156 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <ext/range.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
template <typename T = UInt64>
|
||||
class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>>
|
||||
{
|
||||
private:
|
||||
size_t category_count;
|
||||
|
||||
public:
|
||||
AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) :
|
||||
IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>> {arguments_, params_},
|
||||
category_count {arguments_.size() - 1}
|
||||
{
|
||||
// notice: argument types has been checked before
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "categoricalInformationValue";
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override
|
||||
{
|
||||
return __FILE__;
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr place) const override
|
||||
{
|
||||
memset(place, 0, sizeOfData());
|
||||
}
|
||||
|
||||
void destroy(AggregateDataPtr) const noexcept override
|
||||
{
|
||||
// nothing
|
||||
}
|
||||
|
||||
bool hasTrivialDestructor() const override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t sizeOfData() const override
|
||||
{
|
||||
return sizeof(T) * (category_count + 1) * 2;
|
||||
}
|
||||
|
||||
size_t alignOfData() const override
|
||||
{
|
||||
return alignof(T);
|
||||
}
|
||||
|
||||
void add(
|
||||
AggregateDataPtr place,
|
||||
const IColumn ** columns,
|
||||
size_t row_num,
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
auto y_col = static_cast<const ColumnUInt8 *>(columns[category_count]);
|
||||
bool y = y_col->getData()[row_num];
|
||||
|
||||
for (size_t i : ext::range(0, category_count))
|
||||
{
|
||||
auto x_col = static_cast<const ColumnUInt8 *>(columns[i]);
|
||||
bool x = x_col->getData()[row_num];
|
||||
|
||||
if (x)
|
||||
reinterpret_cast<T *>(place)[i * 2 + size_t(y)] += 1;
|
||||
}
|
||||
|
||||
reinterpret_cast<T *>(place)[category_count * 2 + size_t(y)] += 1;
|
||||
}
|
||||
|
||||
void merge(
|
||||
AggregateDataPtr place,
|
||||
ConstAggregateDataPtr rhs,
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
for (size_t i : ext::range(0, category_count + 1))
|
||||
{
|
||||
reinterpret_cast<T *>(place)[i * 2] += reinterpret_cast<const T *>(rhs)[i * 2];
|
||||
reinterpret_cast<T *>(place)[i * 2 + 1] += reinterpret_cast<const T *>(rhs)[i * 2 + 1];
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(
|
||||
ConstAggregateDataPtr place,
|
||||
WriteBuffer & buf
|
||||
) const override
|
||||
{
|
||||
buf.write(place, sizeOfData());
|
||||
}
|
||||
|
||||
void deserialize(
|
||||
AggregateDataPtr place,
|
||||
ReadBuffer & buf,
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
buf.read(place, sizeOfData());
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(
|
||||
std::make_shared<DataTypeNumber<Float64>>()
|
||||
);
|
||||
}
|
||||
|
||||
void insertResultInto(
|
||||
ConstAggregateDataPtr place,
|
||||
IColumn & to
|
||||
) const override
|
||||
{
|
||||
auto & col = static_cast<ColumnArray &>(to);
|
||||
auto & data_col = static_cast<ColumnFloat64 &>(col.getData());
|
||||
auto & offset_col = static_cast<ColumnArray::ColumnOffsets &>(
|
||||
col.getOffsetsColumn()
|
||||
);
|
||||
|
||||
data_col.reserve(data_col.size() + category_count);
|
||||
|
||||
T sum_no = reinterpret_cast<const T *>(place)[category_count * 2];
|
||||
T sum_yes = reinterpret_cast<const T *>(place)[category_count * 2 + 1];
|
||||
|
||||
Float64 rev_no = 1. / sum_no;
|
||||
Float64 rev_yes = 1. / sum_yes;
|
||||
|
||||
for (size_t i : ext::range(0, category_count))
|
||||
{
|
||||
T no = reinterpret_cast<const T *>(place)[i * 2];
|
||||
T yes = reinterpret_cast<const T *>(place)[i * 2 + 1];
|
||||
|
||||
data_col.insertValue((no * rev_no - yes * rev_yes) * (log(no * rev_no) - log(yes * rev_yes)));
|
||||
}
|
||||
|
||||
offset_col.insertValue(data_col.size());
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -35,6 +35,7 @@ void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionMoving(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory &);
|
||||
|
||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
||||
@ -78,6 +79,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionEntropy(factory);
|
||||
registerAggregateFunctionSimpleLinearRegression(factory);
|
||||
registerAggregateFunctionMoving(factory);
|
||||
registerAggregateFunctionCategoricalIV(factory);
|
||||
}
|
||||
|
||||
{
|
||||
|
24
dbms/tests/performance/information_value.xml
Normal file
24
dbms/tests/performance/information_value.xml
Normal file
@ -0,0 +1,24 @@
|
||||
<test>
|
||||
<type>loop</type>
|
||||
|
||||
<preconditions>
|
||||
<table_exists>test.hits</table_exists>
|
||||
</preconditions>
|
||||
|
||||
<stop_conditions>
|
||||
<all_of>
|
||||
<total_time_ms>10000</total_time_ms>
|
||||
</all_of>
|
||||
<any_of>
|
||||
<average_speed_not_changing_for_ms>5000</average_speed_not_changing_for_ms>
|
||||
<total_time_ms>20000</total_time_ms>
|
||||
</any_of>
|
||||
</stop_conditions>
|
||||
|
||||
<main_metric>
|
||||
<min_time/>
|
||||
</main_metric>
|
||||
|
||||
<query>SELECT categoricalInformationValue(Age < 15, IsMobile)</query>
|
||||
<query>SELECT categoricalInformationValue(Age < 15, Age >= 15 and Age < 30, Age >= 30 and Age < 45, Age >= 45 and Age < 60, Age >= 60, IsMobile)</query>
|
||||
</test>
|
@ -0,0 +1,14 @@
|
||||
[nan]
|
||||
[nan]
|
||||
[nan]
|
||||
[0]
|
||||
[0]
|
||||
[nan]
|
||||
[nan]
|
||||
[inf]
|
||||
[inf]
|
||||
0.135155 0.135155
|
||||
[0,0]
|
||||
0.067578 0.047947 0.067578 0.047947
|
||||
[0,0]
|
||||
0.067578 0.047947 0.067578 0.047947
|
116
dbms/tests/queries/0_stateless/01043_categorical_iv.sql
Normal file
116
dbms/tests/queries/0_stateless/01043_categorical_iv.sql
Normal file
@ -0,0 +1,116 @@
|
||||
-- trivial
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin(arrayPopBack([(1, 0)])) as x
|
||||
);
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(0, 0)]) as x
|
||||
);
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(1, 0)]) as x
|
||||
);
|
||||
|
||||
-- single category
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(1, 0), (1, 0), (1, 0), (1, 1), (1, 1)]) as x
|
||||
);
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(0, 0), (0, 1), (1, 0), (1, 1)]) as x
|
||||
);
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(0, 0), (0, 0), (1, 0), (1, 0)]) as x
|
||||
);
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(0, 1), (0, 1), (1, 1), (1, 1)]) as x
|
||||
);
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(0, 0), (0, 1), (1, 1), (1, 1)]) as x
|
||||
);
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(0, 0), (0, 1), (1, 0), (1, 0)]) as x
|
||||
);
|
||||
|
||||
SELECT
|
||||
round(categoricalInformationValue(x.1, x.2)[1], 6),
|
||||
round((2 / 2 - 2 / 3) * (log(2 / 2) - log(2 / 3)), 6)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(0, 0), (1, 0), (1, 0), (1, 1), (1, 1)]) as x
|
||||
);
|
||||
|
||||
-- multiple category
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2, x.3)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1)]) as x
|
||||
);
|
||||
|
||||
SELECT
|
||||
round(categoricalInformationValue(x.1, x.2, x.3)[1], 6),
|
||||
round(categoricalInformationValue(x.1, x.2, x.3)[2], 6),
|
||||
round((2 / 4 - 1 / 3) * (log(2 / 4) - log(1 / 3)), 6),
|
||||
round((2 / 4 - 2 / 3) * (log(2 / 4) - log(2 / 3)), 6)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1), (0, 1, 1)]) as x
|
||||
);
|
||||
|
||||
-- multiple category, larger data size
|
||||
|
||||
SELECT
|
||||
categoricalInformationValue(x.1, x.2, x.3)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1)]) as x
|
||||
FROM
|
||||
numbers(1000)
|
||||
);
|
||||
|
||||
SELECT
|
||||
round(categoricalInformationValue(x.1, x.2, x.3)[1], 6),
|
||||
round(categoricalInformationValue(x.1, x.2, x.3)[2], 6),
|
||||
round((2 / 4 - 1 / 3) * (log(2 / 4) - log(1 / 3)), 6),
|
||||
round((2 / 4 - 2 / 3) * (log(2 / 4) - log(2 / 3)), 6)
|
||||
FROM (
|
||||
SELECT
|
||||
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1), (0, 1, 1)]) as x
|
||||
FROM
|
||||
numbers(1000)
|
||||
);
|
@ -858,7 +858,7 @@ Don't use this function for calculating timings. There is a more suitable functi
|
||||
|
||||
## quantileTiming {#agg_function-quantiletiming}
|
||||
|
||||
Computes the quantile of the specified level with determined precision. The function is intended for calculating page loading time quantiles in milliseconds.
|
||||
Computes the quantile of the specified level with determined precision. The function is intended for calculating page loading time quantiles in milliseconds.
|
||||
|
||||
```sql
|
||||
quantileTiming(level)(expr)
|
||||
@ -868,7 +868,7 @@ quantileTiming(level)(expr)
|
||||
|
||||
- `level` — Quantile level. Range: [0, 1].
|
||||
- `expr` — [Expression](../syntax.md#syntax-expressions) returning a [Float*](../../data_types/float.md)-type number. The function expects input values in unix timestamp format in milliseconds, but it doesn't validate format.
|
||||
|
||||
|
||||
- If negative values are passed to the function, the behavior is undefined.
|
||||
- If the value is greater than 30,000 (a page loading time of more than 30 seconds), it is assumed to be 30,000.
|
||||
|
||||
@ -1007,6 +1007,16 @@ Calculates the value of `Σ((x - x̅)(y - y̅)) / n`.
|
||||
|
||||
Calculates the Pearson correlation coefficient: `Σ((x - x̅)(y - y̅)) / sqrt(Σ((x - x̅)^2) * Σ((y - y̅)^2))`.
|
||||
|
||||
## categoricalInformationValue
|
||||
|
||||
Calculates the value of `(P(tag = 1) - P(tag = 0))(log(P(tag = 1)) - log(P(tag = 0)))` for each category.
|
||||
|
||||
```sql
|
||||
categoricalInformationValue(category1, category2, ..., tag)
|
||||
```
|
||||
|
||||
The result indicates how a discrete (categorical) feature `[category1, category2, ...]` contribute to a learning model which predicting the value of `tag`.
|
||||
|
||||
## simpleLinearRegression
|
||||
|
||||
Performs simple (unidimensional) linear regression.
|
||||
|
Loading…
Reference in New Issue
Block a user