mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
Aggregate function for entropy
This commit is contained in:
parent
0f577da5c2
commit
8c2726b77c
55
dbms/src/AggregateFunctions/AggregateFunctionEntropy.cpp
Normal file
55
dbms/src/AggregateFunctions/AggregateFunctionEntropy.cpp
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
|
#include <AggregateFunctions/AggregateFunctionEntropy.h>
|
||||||
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||||
|
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
/** Factory callback that builds the `entropy` aggregate function.
  *
  * name           - function name as written in the query; used in the error message.
  * argument_types - types of the arguments; exactly one is required.
  * parameters     - must be empty (checked by assertNoParameters).
  *
  * UInt64, Int64, Int32 and UInt32 arguments get a specialization keyed by the value itself;
  * every other type uses the <UInt128, String> specialization, which keys the map by a hash
  * of the value (see AggregateFunctionEntropy::add).
  *
  * Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH unless exactly one argument is given.
  */
AggregateFunctionPtr createAggregateFunctionEntropy(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
    assertNoParameters(name, parameters);

    /// The implementation reads only columns[0] (and hashes with an argument count of 1),
    /// so any further arguments would be silently ignored. Reject them explicitly.
    if (argument_types.size() != 1)
        throw Exception("Incorrect number of arguments for aggregate function " + name,
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    WhichDataType which(argument_types[0]);

    if (which.isUInt64())
        return std::make_shared<AggregateFunctionEntropy<UInt64, UInt64>>();

    if (which.isInt64())
        return std::make_shared<AggregateFunctionEntropy<Int64, Int64>>();

    if (which.isInt32())
        return std::make_shared<AggregateFunctionEntropy<Int32, Int32>>();

    if (which.isUInt32())
        return std::make_shared<AggregateFunctionEntropy<UInt32, UInt32>>();

    /// Generic path for all remaining types (other numbers, strings, dates, ...).
    return std::make_shared<AggregateFunctionEntropy<UInt128, String>>();
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Makes the aggregate function available under the name "entropy" in the given factory.
void registerAggregateFunctionEntropy(AggregateFunctionFactory & factory)
{
    auto * creator = &createAggregateFunctionEntropy;
    factory.registerFunction("entropy", creator);
}
|
||||||
|
|
||||||
|
}
|
161
dbms/src/AggregateFunctions/AggregateFunctionEntropy.h
Normal file
161
dbms/src/AggregateFunctions/AggregateFunctionEntropy.h
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
|
|
||||||
|
/// These must be exposed in header for the purpose of dynamic compilation.
|
||||||
|
#include <AggregateFunctions/QuantileReservoirSampler.h>
|
||||||
|
#include <AggregateFunctions/QuantileReservoirSamplerDeterministic.h>
|
||||||
|
#include <AggregateFunctions/QuantileExact.h>
|
||||||
|
#include <AggregateFunctions/QuantileExactWeighted.h>
|
||||||
|
#include <AggregateFunctions/QuantileTiming.h>
|
||||||
|
#include <AggregateFunctions/QuantileTDigest.h>
|
||||||
|
|
||||||
|
#include <AggregateFunctions/IAggregateFunction.h>
|
||||||
|
#include <AggregateFunctions/QuantilesCommon.h>
|
||||||
|
#include <Columns/ColumnArray.h>
|
||||||
|
#include <Columns/ColumnDecimal.h>
|
||||||
|
#include <Columns/ColumnsNumber.h>
|
||||||
|
#include <DataTypes/DataTypeArray.h>
|
||||||
|
#include <DataTypes/DataTypeDate.h>
|
||||||
|
#include <DataTypes/DataTypeDateTime.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <IO/ReadHelpers.h>
|
||||||
|
#include <IO/WriteHelpers.h>
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <AggregateFunctions/UniqVariadicHash.h>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function
|
||||||
|
*/
|
||||||
|
template <typename Value>
|
||||||
|
struct EntropyData
|
||||||
|
{
|
||||||
|
using Weight = UInt64;
|
||||||
|
using Map = HashMap <
|
||||||
|
Value, Weight,
|
||||||
|
HashCRC32<Value>,
|
||||||
|
HashTableGrower<4>,
|
||||||
|
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>
|
||||||
|
>;
|
||||||
|
|
||||||
|
Map map;
|
||||||
|
|
||||||
|
void add(const Value &x)
|
||||||
|
{
|
||||||
|
if (!isNaN(x))
|
||||||
|
++map[x];
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(const Value &x, const Weight &weight)
|
||||||
|
{
|
||||||
|
if (!isNaN(x))
|
||||||
|
map[x] += weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(const EntropyData &rhs)
|
||||||
|
{
|
||||||
|
for (const auto &pair : rhs.map)
|
||||||
|
map[pair.first] += pair.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(WriteBuffer &buf) const
|
||||||
|
{
|
||||||
|
map.write(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deserialize(ReadBuffer &buf)
|
||||||
|
{
|
||||||
|
typename Map::Reader reader(buf);
|
||||||
|
while (reader.next())
|
||||||
|
{
|
||||||
|
const auto &pair = reader.get();
|
||||||
|
map[pair.first] = pair.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 get() const
|
||||||
|
{
|
||||||
|
Float64 ShannonEntropy = 0;
|
||||||
|
UInt64 TotalValue = 0;
|
||||||
|
for (const auto & pair : map)
|
||||||
|
{
|
||||||
|
TotalValue += pair.second;
|
||||||
|
}
|
||||||
|
Float64 cur_proba;
|
||||||
|
Float64 log2e = 1 / std::log(2);
|
||||||
|
for (const auto & pair : map)
|
||||||
|
{
|
||||||
|
cur_proba = Float64(pair.second) / TotalValue;
|
||||||
|
ShannonEntropy -= cur_proba * std::log(cur_proba) * log2e;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ShannonEntropy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Value, typename UserValue>
|
||||||
|
class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>,
|
||||||
|
AggregateFunctionEntropy<Value, UserValue>>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
AggregateFunctionEntropy()
|
||||||
|
{}
|
||||||
|
|
||||||
|
String getName() const override { return "entropy"; }
|
||||||
|
|
||||||
|
DataTypePtr getReturnType() const override
|
||||||
|
{
|
||||||
|
return std::make_shared<DataTypeNumber<Float64>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
|
{
|
||||||
|
if constexpr (!std::is_same_v<UInt128, Value>)
|
||||||
|
{
|
||||||
|
const auto &column = static_cast<const ColumnVector <UserValue> &>(*columns[0]);
|
||||||
|
this->data(place).add(column.getData()[row_num]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
this->data(place).add(UniqVariadicHash<true, false>::apply(1, columns, row_num));
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).merge(this->data(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||||
|
{
|
||||||
|
this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).deserialize(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||||
|
{
|
||||||
|
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
|
||||||
|
column.getData().push_back(this->data(place).get());
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -19,7 +19,7 @@ namespace ErrorCodes
|
|||||||
/** Calculates quantile by collecting all values into array
|
/** Calculates quantile by collecting all values into array
|
||||||
* and applying n-th element (introselect) algorithm for the resulting array.
|
* and applying n-th element (introselect) algorithm for the resulting array.
|
||||||
*
|
*
|
||||||
* It use O(N) memory and it is very inefficient in case of high amount of identical values.
|
* It uses O(N) memory and it is very inefficient in case of high amount of identical values.
|
||||||
* But it is very CPU efficient for not large datasets.
|
* But it is very CPU efficient for not large datasets.
|
||||||
*/
|
*/
|
||||||
template <typename Value>
|
template <typename Value>
|
||||||
|
@ -14,7 +14,7 @@ namespace ErrorCodes
|
|||||||
|
|
||||||
/** Calculates quantile by counting number of occurrences for each value in a hash map.
|
/** Calculates quantile by counting number of occurrences for each value in a hash map.
|
||||||
*
|
*
|
||||||
* It use O(distinct(N)) memory. Can be naturally applied for values with weight.
|
* It uses O(distinct(N)) memory. Can be naturally applied for values with weight.
|
||||||
* In case of many identical values, it can be more efficient than QuantileExact even when weight is not used.
|
* In case of many identical values, it can be more efficient than QuantileExact even when weight is not used.
|
||||||
*/
|
*/
|
||||||
template <typename Value>
|
template <typename Value>
|
||||||
|
@ -27,6 +27,7 @@ void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory &);
|
|||||||
void registerAggregateFunctionTopK(AggregateFunctionFactory &);
|
void registerAggregateFunctionTopK(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
|
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
||||||
|
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
||||||
|
|
||||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||||
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
||||||
@ -65,6 +66,7 @@ void registerAggregateFunctions()
|
|||||||
registerAggregateFunctionsMaxIntersections(factory);
|
registerAggregateFunctionsMaxIntersections(factory);
|
||||||
registerAggregateFunctionHistogram(factory);
|
registerAggregateFunctionHistogram(factory);
|
||||||
registerAggregateFunctionRetention(factory);
|
registerAggregateFunctionRetention(factory);
|
||||||
|
registerAggregateFunctionEntropy(factory);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
5
dbms/tests/queries/0_stateless/00902_entropy.reference
Normal file
5
dbms/tests/queries/0_stateless/00902_entropy.reference
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
1
|
||||||
|
1
|
||||||
|
1
|
||||||
|
1
|
||||||
|
1
|
49
dbms/tests/queries/0_stateless/00902_entropy.sql
Normal file
49
dbms/tests/queries/0_stateless/00902_entropy.sql
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
-- Checks entropy() on several argument types. Every SELECT compares the result
-- against a narrow interval and is expected to print 1.

-- String values: counts are 3 ('ba'), 2 ('aa'), 1 ('b').
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
    vals String
) ENGINE = Memory;

insert into test.defaults values ('ba'), ('aa'), ('ba'), ('b'), ('ba'), ('aa');
select val < 1.5 and val > 1.459 from (select entropy(vals) as val from test.defaults);


-- UInt64 values.
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
    vals UInt64
) ENGINE = Memory;
insert into test.defaults values (0), (0), (1), (0), (0), (0), (1), (2), (3), (5), (3), (1), (1), (4), (5), (2);
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);


-- UInt32 values (same distribution as above).
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
    vals UInt32
) ENGINE = Memory;
insert into test.defaults values (0), (0), (1), (0), (0), (0), (1), (2), (3), (5), (3), (1), (1), (4), (5), (2);
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);


-- Int32 values, including negatives (same distribution as above).
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
    vals Int32
) ENGINE = Memory;
insert into test.defaults values (0), (0), (-1), (0), (0), (0), (-1), (2), (3), (5), (3), (-1), (-1), (4), (5), (2);
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);


-- DateTime values: five distinct timestamps with counts 4, 3, 2, 2, 1.
-- NOTE(review): the literals use hour 24; presumably parsed leniently - confirm.
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
    vals DateTime
) ENGINE = Memory;
insert into test.defaults values (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2017-06-15 24:00:00')), (toDateTime('2017-06-15 24:00:00')), (toDateTime('2018-06-15 24:00:00')), (toDateTime('2018-06-15 24:00:00')), (toDateTime('2019-06-15 24:00:00'));
select val < 2.189 and val > 2.1886 from (select entropy(vals) as val from test.defaults);

-- Do not leave the table behind for other tests.
DROP TABLE IF EXISTS test.defaults;
|
Loading…
Reference in New Issue
Block a user