Add aggregate function categoricalInformationValue (#8117)

* Add categorical iv aggregate function with tests
This commit is contained in:
hcz 2019-12-12 21:28:28 +08:00 committed by Olga Khvostikova
parent b8857e1a09
commit c0028c3942
7 changed files with 377 additions and 2 deletions

View File

@ -0,0 +1,53 @@
#include <AggregateFunctions/AggregateFunctionCategoricalInformationValue.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
AggregateFunctionPtr createAggregateFunctionCategoricalIV(
const std::string & name,
const DataTypes & arguments,
const Array & params
)
{
assertNoParameters(name, params);
if (arguments.size() < 2)
throw Exception(
"Aggregate function " + name + " requires two or more arguments",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (auto & argument : arguments)
{
if (!WhichDataType(argument).isUInt8())
throw Exception(
"All the arguments of aggregate function " + name + " should be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
return std::make_shared<AggregateFunctionCategoricalIV<>>(arguments, params);
}
}
void registerAggregateFunctionCategoricalIV(
AggregateFunctionFactory & factory
)
{
factory.registerFunction("categoricalInformationValue", createAggregateFunctionCategoricalIV);
}
}

View File

@ -0,0 +1,156 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <ext/range.h>
namespace DB
{
template <typename T = UInt64>
class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>>
{
private:
size_t category_count;
public:
AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) :
IAggregateFunctionHelper<AggregateFunctionCategoricalIV<T>> {arguments_, params_},
category_count {arguments_.size() - 1}
{
// notice: argument types has been checked before
}
String getName() const override
{
return "categoricalInformationValue";
}
const char * getHeaderFilePath() const override
{
return __FILE__;
}
void create(AggregateDataPtr place) const override
{
memset(place, 0, sizeOfData());
}
void destroy(AggregateDataPtr) const noexcept override
{
// nothing
}
bool hasTrivialDestructor() const override
{
return true;
}
size_t sizeOfData() const override
{
return sizeof(T) * (category_count + 1) * 2;
}
size_t alignOfData() const override
{
return alignof(T);
}
void add(
AggregateDataPtr place,
const IColumn ** columns,
size_t row_num,
Arena *
) const override
{
auto y_col = static_cast<const ColumnUInt8 *>(columns[category_count]);
bool y = y_col->getData()[row_num];
for (size_t i : ext::range(0, category_count))
{
auto x_col = static_cast<const ColumnUInt8 *>(columns[i]);
bool x = x_col->getData()[row_num];
if (x)
reinterpret_cast<T *>(place)[i * 2 + size_t(y)] += 1;
}
reinterpret_cast<T *>(place)[category_count * 2 + size_t(y)] += 1;
}
void merge(
AggregateDataPtr place,
ConstAggregateDataPtr rhs,
Arena *
) const override
{
for (size_t i : ext::range(0, category_count + 1))
{
reinterpret_cast<T *>(place)[i * 2] += reinterpret_cast<const T *>(rhs)[i * 2];
reinterpret_cast<T *>(place)[i * 2 + 1] += reinterpret_cast<const T *>(rhs)[i * 2 + 1];
}
}
void serialize(
ConstAggregateDataPtr place,
WriteBuffer & buf
) const override
{
buf.write(place, sizeOfData());
}
void deserialize(
AggregateDataPtr place,
ReadBuffer & buf,
Arena *
) const override
{
buf.read(place, sizeOfData());
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(
std::make_shared<DataTypeNumber<Float64>>()
);
}
void insertResultInto(
ConstAggregateDataPtr place,
IColumn & to
) const override
{
auto & col = static_cast<ColumnArray &>(to);
auto & data_col = static_cast<ColumnFloat64 &>(col.getData());
auto & offset_col = static_cast<ColumnArray::ColumnOffsets &>(
col.getOffsetsColumn()
);
data_col.reserve(data_col.size() + category_count);
T sum_no = reinterpret_cast<const T *>(place)[category_count * 2];
T sum_yes = reinterpret_cast<const T *>(place)[category_count * 2 + 1];
Float64 rev_no = 1. / sum_no;
Float64 rev_yes = 1. / sum_yes;
for (size_t i : ext::range(0, category_count))
{
T no = reinterpret_cast<const T *>(place)[i * 2];
T yes = reinterpret_cast<const T *>(place)[i * 2 + 1];
data_col.insertValue((no * rev_no - yes * rev_yes) * (log(no * rev_no) - log(yes * rev_yes)));
}
offset_col.insertValue(data_col.size());
}
};
}

View File

@ -35,6 +35,7 @@ void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
void registerAggregateFunctionMoving(AggregateFunctionFactory &);
void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -78,6 +79,7 @@ void registerAggregateFunctions()
registerAggregateFunctionEntropy(factory);
registerAggregateFunctionSimpleLinearRegression(factory);
registerAggregateFunctionMoving(factory);
registerAggregateFunctionCategoricalIV(factory);
}
{

View File

@ -0,0 +1,24 @@
<test>
<type>loop</type>
<preconditions>
<table_exists>test.hits</table_exists>
</preconditions>
<stop_conditions>
<all_of>
<total_time_ms>10000</total_time_ms>
</all_of>
<any_of>
<average_speed_not_changing_for_ms>5000</average_speed_not_changing_for_ms>
<total_time_ms>20000</total_time_ms>
</any_of>
</stop_conditions>
<main_metric>
<min_time/>
</main_metric>
<query>SELECT categoricalInformationValue(Age &lt; 15, IsMobile)</query>
<query>SELECT categoricalInformationValue(Age &lt; 15, Age &gt;= 15 and Age &lt; 30, Age &gt;= 30 and Age &lt; 45, Age &gt;= 45 and Age &lt; 60, Age &gt;= 60, IsMobile)</query>
</test>

View File

@ -0,0 +1,14 @@
[nan]
[nan]
[nan]
[0]
[0]
[nan]
[nan]
[inf]
[inf]
0.135155 0.135155
[0,0]
0.067578 0.047947 0.067578 0.047947
[0,0]
0.067578 0.047947 0.067578 0.047947

View File

@ -0,0 +1,116 @@
-- trivial
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin(arrayPopBack([(1, 0)])) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(1, 0)]) as x
);
-- single category
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(1, 0), (1, 0), (1, 0), (1, 1), (1, 1)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0), (0, 1), (1, 0), (1, 1)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0), (0, 0), (1, 0), (1, 0)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 1), (0, 1), (1, 1), (1, 1)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0), (0, 1), (1, 1), (1, 1)]) as x
);
SELECT
categoricalInformationValue(x.1, x.2)
FROM (
SELECT
arrayJoin([(0, 0), (0, 1), (1, 0), (1, 0)]) as x
);
SELECT
round(categoricalInformationValue(x.1, x.2)[1], 6),
round((2 / 2 - 2 / 3) * (log(2 / 2) - log(2 / 3)), 6)
FROM (
SELECT
arrayJoin([(0, 0), (1, 0), (1, 0), (1, 1), (1, 1)]) as x
);
-- multiple category
SELECT
categoricalInformationValue(x.1, x.2, x.3)
FROM (
SELECT
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1)]) as x
);
SELECT
round(categoricalInformationValue(x.1, x.2, x.3)[1], 6),
round(categoricalInformationValue(x.1, x.2, x.3)[2], 6),
round((2 / 4 - 1 / 3) * (log(2 / 4) - log(1 / 3)), 6),
round((2 / 4 - 2 / 3) * (log(2 / 4) - log(2 / 3)), 6)
FROM (
SELECT
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1), (0, 1, 1)]) as x
);
-- multiple category, larger data size
SELECT
categoricalInformationValue(x.1, x.2, x.3)
FROM (
SELECT
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1)]) as x
FROM
numbers(1000)
);
SELECT
round(categoricalInformationValue(x.1, x.2, x.3)[1], 6),
round(categoricalInformationValue(x.1, x.2, x.3)[2], 6),
round((2 / 4 - 1 / 3) * (log(2 / 4) - log(1 / 3)), 6),
round((2 / 4 - 2 / 3) * (log(2 / 4) - log(2 / 3)), 6)
FROM (
SELECT
arrayJoin([(1, 0, 0), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 0), (0, 1, 1), (0, 1, 1)]) as x
FROM
numbers(1000)
);

View File

@ -1007,6 +1007,16 @@ Calculates the value of `Σ((x - x̅)(y - y̅)) / n`.
Calculates the Pearson correlation coefficient: `Σ((x - x̅)(y - y̅)) / sqrt(Σ((x - x̅)^2) * Σ((y - y̅)^2))`.
## categoricalInformationValue
Calculates the value of `(P(tag = 1) - P(tag = 0))(log(P(tag = 1)) - log(P(tag = 0)))` for each category.
```sql
categoricalInformationValue(category1, category2, ..., tag)
```
The result indicates how a discrete (categorical) feature `[category1, category2, ...]` contribute to a learning model which predicting the value of `tag`.
## simpleLinearRegression
Performs simple (unidimensional) linear regression.