Add aggregate function categoricalInformationValue (#8117)

* Add categorical iv aggregate function with tests
This commit is contained in:
hcz 2019-12-12 21:28:28 +08:00 committed by Olga Khvostikova
parent b8857e1a09
commit c0028c3942
7 changed files with 377 additions and 2 deletions

View File

@ -0,0 +1,53 @@
#include <AggregateFunctions/AggregateFunctionCategoricalInformationValue.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
AggregateFunctionPtr createAggregateFunctionCategoricalIV(
const std::string & name,
const DataTypes & arguments,
const Array & params
)
{
assertNoParameters(name, params);
if (arguments.size() < 2)
throw Exception(
"Aggregate function " + name + " requires two or more arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (auto & argument : arguments)
{
if (!WhichDataType(argument).isUInt8())
throw Exception(
"All the arguments of aggregate function " + name + " should be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
return std::make_shared<AggregateFunctionCategoricalIV<>>(arguments, params);
}
}
void registerAggregateFunctionCategoricalIV(
AggregateFunctionFactory & factory
)
{
factory.registerFunction("categoricalInformationValue", createAggregateFunctionCategoricalIV);
}
}

View File

@ -0,0 +1,156 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <ext/range.h>
namespace DB
{
template <typename T = UInt64>
class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>>
{
private:
size_t category_count;
public:
AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) :
IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>> {arguments_, params_},
category_count {arguments_.size() - 1}
{
// notice: argument types has been checked before
}
String getName() const override
{
return "categoricalInformationValue";
}
const char * getHeaderFilePath() const override
{
return __FILE__;
}
void create(AggregateDataPtr place) const override
{
memset(place, 0, sizeOfData());
}
void destroy(AggregateDataPtr) const noexcept override
{
// nothing
}
bool hasTrivialDestructor() const override
{
return true;
}
size_t sizeOfData() const override
{
return sizeof(T) * (category_count + 1) * 2;
}
size_t alignOfData() const override
{
return alignof(T);
}
void add(
AggregateDataPtr place,
const IColumn ** columns,
size_t row_num,
Arena *
) const override
{
auto y_col = static_cast<const ColumnUInt8 *>(columns[category_count]);
bool y = y_col->getData()[row_num];
for (size_t i : ext::range(0, category_count))
{
auto x_col = static_cast<const ColumnUInt8 *>(columns[i]);
bool x = x_col->getData()[row_num];
if (x)
reinterpret_cast<T *>(place)[i * 2 + size_t(y)] += 1;
}
reinterpret_cast<T *>(place)[category_count * 2 + size_t(y)] += 1;
}
void merge(
AggregateDataPtr place,
ConstAggregateDataPtr rhs,
Arena *
) const override
{
for (size_t i : ext::range(0, category_count + 1))
{
reinterpret_cast<T *>(place)[i * 2] += reinterpret_cast<const T *>(rhs)[i * 2];
reinterpret_cast<T *>(place)[i * 2 + 1] += reinterpret_cast<const T *>(rhs)[i * 2 + 1];
}
}
void serialize(
ConstAggregateDataPtr place,
WriteBuffer & buf
) const override
{
buf.write(place, sizeOfData());
}
void deserialize(
AggregateDataPtr place,
ReadBuffer & buf,
Arena *
) const override
{
buf.read(place, sizeOfData());
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(
std::make_shared<DataTypeNumber<Float64>>()
);
}
void insertResultInto(
ConstAggregateDataPtr place,
IColumn & to
) const override
{
auto & col = static_cast<ColumnArray &>(to);
auto & data_col = static_cast<ColumnFloat64 &>(col.getData());
auto & offset_col = static_cast<ColumnArray::ColumnOffsets &>(
col.getOffsetsColumn()
);
data_col.reserve(data_col.size() + category_count);
T sum_no = reinterpret_cast<const T *>(place)[category_count * 2];
T sum_yes = reinterpret_cast<const T *>(place)[category_count * 2 + 1];
Float64 rev_no = 1. / sum_no;
Float64 rev_yes = 1. / sum_yes;
for (size_t i : ext::range(0, category_count))
{
T no = reinterpret_cast<const T *>(place)[i * 2];
T yes = reinterpret_cast<const T *>(place)[i * 2 + 1];
data_col.insertValue((no * rev_no - yes * rev_yes) * (log(no * rev_no) - log(yes * rev_yes)));
}
offset_col.insertValue(data_col.size());
}
};
}

View File

@ -35,6 +35,7 @@ void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
void registerAggregateFunctionMoving(AggregateFunctionFactory &);
void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -78,6 +79,7 @@ void registerAggregateFunctions()
registerAggregateFunctionEntropy(factory);
registerAggregateFunctionSimpleLinearRegression(factory);
registerAggregateFunctionMoving(factory);
registerAggregateFunctionCategoricalIV(factory);
}
{

View File

@ -0,0 +1,24 @@
<test>
<type>loop</type>
<preconditions>
<table_exists>test.hits</table_exists>
</preconditions>
<stop_conditions>
<all_of>
<total_time_ms>10000</total_time_ms>
</all_of>
<any_of>
<average_speed_not_changing_for_ms>5000</average_speed_not_changing_for_ms>
<total_time_ms>20000</total_time_ms>
</any_of>
</stop_conditions>
<main_metric>
<min_time/>
</main_metric>
<query>SELECT categoricalInformationValue(Age &lt; 15, IsMobile)</query>
<query>SELECT categoricalInformationValue(Age &lt; 15, Age &gt;= 15 and Age &lt; 30, Age &gt;= 30 and Age &lt; 45, Age &gt;= 45 and Age &lt; 60, Age &gt;= 60, IsMobile)</query>
</test>

View File

@ -0,0 +1,14 @@
[nan]
[nan]
[nan]
[0]
[0]
[nan]
[nan]
[inf]
[inf]
0.135155 0.135155
[0,0]
0.067578 0.047947 0.067578 0.047947
[0,0]
0.067578 0.047947 0.067578 0.047947

View File

@ -0,0 +1,116 @@
-- trivial
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin(arrayPopBack([(1, 0)])) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(1, 0)]) as x
);
-- single category
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(1, 0), (1, 0), (1, 0), (1, 1), (1, 1)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0), (0, 1), (1, 0), (1, 1)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0), (0, 0), (1, 0), (1, 0)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 1), (0, 1), (1, 1), (1, 1)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0), (0, 1), (1, 1), (1, 1)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0), (0, 1), (1, 0), (1, 0)]) as x
);
SELECT
round(categoricalInformationValue(x.1, x.2)[1], 6),
round((2 / 2 - 2 / 3) * (log(2 / 2) - log(2 / 3)), 6)
FROM (
SELECT
arrayJoin([(0, 0), (1, 0), (1, 0), (1, 1), (1, 1)]) as x
);
-- multiple category
SELECT
categoricalInformationValue(x.1, x.2, x.3)
FROM (
SELECT
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1)]) as x
);
SELECT
round(categoricalInformationValue(x.1, x.2, x.3)[1], 6),
round(categoricalInformationValue(x.1, x.2, x.3)[2], 6),
round((2 / 4 - 1 / 3) * (log(2 / 4) - log(1 / 3)), 6),
round((2 / 4 - 2 / 3) * (log(2 / 4) - log(2 / 3)), 6)
FROM (
SELECT
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1), (0, 1, 1)]) as x
);
-- multiple category, larger data size
SELECT
categoricalInformationValue(x.1, x.2, x.3)
FROM (
SELECT
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1)]) as x
FROM
numbers(1000)
);
SELECT
round(categoricalInformationValue(x.1, x.2, x.3)[1], 6),
round(categoricalInformationValue(x.1, x.2, x.3)[2], 6),
round((2 / 4 - 1 / 3) * (log(2 / 4) - log(1 / 3)), 6),
round((2 / 4 - 2 / 3) * (log(2 / 4) - log(2 / 3)), 6)
FROM (
SELECT
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1), (0, 1, 1)]) as x
FROM
numbers(1000)
);

View File

@ -858,7 +858,7 @@ Don't use this function for calculating timings. There is a more suitable functi
## quantileTiming {#agg_function-quantiletiming}
Computes the quantile of the specified level with determined precision. The function is intended for calculating page loading time quantiles in milliseconds.
Computes the quantile of the specified level with determined precision. The function is intended for calculating page loading time quantiles in milliseconds.
```sql
quantileTiming(level)(expr)
@ -868,7 +868,7 @@ quantileTiming(level)(expr)
- `level` — Quantile level. Range: [0, 1].
- `expr` — [Expression](../syntax.md#syntax-expressions) returning a [Float*](../../data_types/float.md)-type number. The function expects input values in unix timestamp format in milliseconds, but it doesn't validate format.
- If negative values are passed to the function, the behavior is undefined.
- If the value is greater than 30,000 (a page loading time of more than 30 seconds), it is assumed to be 30,000.
@ -1007,6 +1007,16 @@ Calculates the value of `Σ((x - x̅)(y - y̅)) / n`.
Calculates the Pearson correlation coefficient: `Σ((x - x̅)(y - y̅)) / sqrt(Σ((x - x̅)^2) * Σ((y - y̅)^2))`.
## categoricalInformationValue
Calculates the value of `(P(tag = 1) - P(tag = 0))(log(P(tag = 1)) - log(P(tag = 0)))` for each category.
```sql
categoricalInformationValue(category1, category2, ..., tag)
```
The result indicates how a discrete (categorical) feature `[category1, category2, ...]` contribute to a learning model which predicting the value of `tag`.
## simpleLinearRegression
Performs simple (unidimensional) linear regression.