Aggregate function for entropy

This commit is contained in:
alexander kozhikhov 2019-02-02 17:27:43 +03:00
parent 0f577da5c2
commit 8c2726b77c
7 changed files with 274 additions and 2 deletions

View File

@ -0,0 +1,55 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionEntropy.h>
#include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
AggregateFunctionPtr createAggregateFunctionEntropy(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
if (argument_types.empty())
throw Exception("Incorrect number of arguments for aggregate function " + name,
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
WhichDataType which(argument_types[0]);
if (isNumber(argument_types[0]))
{
if (which.isUInt64())
{
return std::make_shared<AggregateFunctionEntropy<UInt64, UInt64>>();
}
else if (which.isInt64())
{
return std::make_shared<AggregateFunctionEntropy<Int64, Int64>>();
}
else if (which.isInt32())
{
return std::make_shared<AggregateFunctionEntropy<Int32, Int32>>();
}
else if (which.isUInt32())
{
return std::make_shared<AggregateFunctionEntropy<UInt32, UInt32>>();
}
}
return std::make_shared<AggregateFunctionEntropy<UInt128, String>>();
}
}
void registerAggregateFunctionEntropy(AggregateFunctionFactory & factory)
{
factory.registerFunction("entropy", createAggregateFunctionEntropy);
}
}

View File

@ -0,0 +1,161 @@
#pragma once
#include <AggregateFunctions/FactoryHelpers.h>
/// These must be exposed in header for the purpose of dynamic compilation.
#include <AggregateFunctions/QuantileReservoirSampler.h>
#include <AggregateFunctions/QuantileReservoirSamplerDeterministic.h>
#include <AggregateFunctions/QuantileExact.h>
#include <AggregateFunctions/QuantileExactWeighted.h>
#include <AggregateFunctions/QuantileTiming.h>
#include <AggregateFunctions/QuantileTDigest.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/QuantilesCommon.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <type_traits>
#include <cmath>
#include <AggregateFunctions/UniqVariadicHash.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function
*/
template <typename Value>
struct EntropyData
{
using Weight = UInt64;
using Map = HashMap <
Value, Weight,
HashCRC32<Value>,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>
>;
Map map;
void add(const Value &x)
{
if (!isNaN(x))
++map[x];
}
void add(const Value &x, const Weight &weight)
{
if (!isNaN(x))
map[x] += weight;
}
void merge(const EntropyData &rhs)
{
for (const auto &pair : rhs.map)
map[pair.first] += pair.second;
}
void serialize(WriteBuffer &buf) const
{
map.write(buf);
}
void deserialize(ReadBuffer &buf)
{
typename Map::Reader reader(buf);
while (reader.next())
{
const auto &pair = reader.get();
map[pair.first] = pair.second;
}
}
Float64 get() const
{
Float64 ShannonEntropy = 0;
UInt64 TotalValue = 0;
for (const auto & pair : map)
{
TotalValue += pair.second;
}
Float64 cur_proba;
Float64 log2e = 1 / std::log(2);
for (const auto & pair : map)
{
cur_proba = Float64(pair.second) / TotalValue;
ShannonEntropy -= cur_proba * std::log(cur_proba) * log2e;
}
return ShannonEntropy;
}
};
template <typename Value, typename UserValue>
class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>,
AggregateFunctionEntropy<Value, UserValue>>
{
public:
AggregateFunctionEntropy()
{}
String getName() const override { return "entropy"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<Float64>>();
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (!std::is_same_v<UInt128, Value>)
{
const auto &column = static_cast<const ColumnVector <UserValue> &>(*columns[0]);
this->data(place).add(column.getData()[row_num]);
}
else
{
this->data(place).add(UniqVariadicHash<true, false>::apply(1, columns, row_num));
}
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
this->data(place).deserialize(buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
column.getData().push_back(this->data(place).get());
}
const char * getHeaderFilePath() const override { return __FILE__; }
};
}

View File

@ -19,7 +19,7 @@ namespace ErrorCodes
/** Calculates quantile by collecting all values into array
* and applying n-th element (introselect) algorithm for the resulting array.
*
* It use O(N) memory and it is very inefficient in case of high amount of identical values.
* It uses O(N) memory and it is very inefficient in case of high amount of identical values.
* But it is very CPU efficient for not large datasets.
*/
template <typename Value>

View File

@ -14,7 +14,7 @@ namespace ErrorCodes
/** Calculates quantile by counting number of occurrences for each value in a hash map.
*
* It use O(distinct(N)) memory. Can be naturally applied for values with weight.
* It uses O(distinct(N)) memory. Can be naturally applied for values with weight.
* In case of many identical values, it can be more efficient than QuantileExact even when weight is not used.
*/
template <typename Value>

View File

@ -27,6 +27,7 @@ void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory &);
void registerAggregateFunctionTopK(AggregateFunctionFactory &);
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -65,6 +66,7 @@ void registerAggregateFunctions()
registerAggregateFunctionsMaxIntersections(factory);
registerAggregateFunctionHistogram(factory);
registerAggregateFunctionRetention(factory);
registerAggregateFunctionEntropy(factory);
}
{

View File

@ -0,0 +1,5 @@
1
1
1
1
1

View File

@ -0,0 +1,49 @@
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals String
) ENGINE = Memory;
insert into test.defaults values ('ba'), ('aa'), ('ba'), ('b'), ('ba'), ('aa');
select val < 1.5 and val > 1.459 from (select entropy(vals) as val from test.defaults);
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals UInt64
) ENGINE = Memory;
insert into test.defaults values (0), (0), (1), (0), (0), (0), (1), (2), (3), (5), (3), (1), (1), (4), (5), (2)
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals UInt32
) ENGINE = Memory;
insert into test.defaults values (0), (0), (1), (0), (0), (0), (1), (2), (3), (5), (3), (1), (1), (4), (5), (2)
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals Int32
) ENGINE = Memory;
insert into test.defaults values (0), (0), (-1), (0), (0), (0), (-1), (2), (3), (5), (3), (-1), (-1), (4), (5), (2)
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals DateTime
) ENGINE = Memory;
insert into test.defaults values (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2017-06-15 24:00:00')), (toDateTime('2017-06-15 24:00:00')), (toDateTime('2018-06-15 24:00:00')), (toDateTime('2018-06-15 24:00:00')), (toDateTime('2019-06-15 24:00:00'));
select val < 2.189 and val > 2.1886 from (select entropy(vals) as val from test.defaults);