mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Aggregate function for entropy
This commit is contained in:
parent
0f577da5c2
commit
8c2726b77c
55
dbms/src/AggregateFunctions/AggregateFunctionEntropy.cpp
Normal file
55
dbms/src/AggregateFunctions/AggregateFunctionEntropy.cpp
Normal file
@ -0,0 +1,55 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionEntropy.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionEntropy(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
if (argument_types.empty())
|
||||
throw Exception("Incorrect number of arguments for aggregate function " + name,
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
WhichDataType which(argument_types[0]);
|
||||
if (isNumber(argument_types[0]))
|
||||
{
|
||||
if (which.isUInt64())
|
||||
{
|
||||
return std::make_shared<AggregateFunctionEntropy<UInt64, UInt64>>();
|
||||
}
|
||||
else if (which.isInt64())
|
||||
{
|
||||
return std::make_shared<AggregateFunctionEntropy<Int64, Int64>>();
|
||||
}
|
||||
else if (which.isInt32())
|
||||
{
|
||||
return std::make_shared<AggregateFunctionEntropy<Int32, Int32>>();
|
||||
}
|
||||
else if (which.isUInt32())
|
||||
{
|
||||
return std::make_shared<AggregateFunctionEntropy<UInt32, UInt32>>();
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_shared<AggregateFunctionEntropy<UInt128, String>>();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionEntropy(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("entropy", createAggregateFunctionEntropy);
|
||||
}
|
||||
|
||||
}
|
161
dbms/src/AggregateFunctions/AggregateFunctionEntropy.h
Normal file
161
dbms/src/AggregateFunctions/AggregateFunctionEntropy.h
Normal file
@ -0,0 +1,161 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
/// These must be exposed in header for the purpose of dynamic compilation.
|
||||
#include <AggregateFunctions/QuantileReservoirSampler.h>
|
||||
#include <AggregateFunctions/QuantileReservoirSamplerDeterministic.h>
|
||||
#include <AggregateFunctions/QuantileExact.h>
|
||||
#include <AggregateFunctions/QuantileExactWeighted.h>
|
||||
#include <AggregateFunctions/QuantileTiming.h>
|
||||
#include <AggregateFunctions/QuantileTDigest.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/QuantilesCommon.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnDecimal.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include <cmath>
|
||||
#include <AggregateFunctions/UniqVariadicHash.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
}
|
||||
|
||||
|
||||
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function
|
||||
*/
|
||||
template <typename Value>
|
||||
struct EntropyData
|
||||
{
|
||||
using Weight = UInt64;
|
||||
using Map = HashMap <
|
||||
Value, Weight,
|
||||
HashCRC32<Value>,
|
||||
HashTableGrower<4>,
|
||||
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>
|
||||
>;
|
||||
|
||||
Map map;
|
||||
|
||||
void add(const Value &x)
|
||||
{
|
||||
if (!isNaN(x))
|
||||
++map[x];
|
||||
}
|
||||
|
||||
void add(const Value &x, const Weight &weight)
|
||||
{
|
||||
if (!isNaN(x))
|
||||
map[x] += weight;
|
||||
}
|
||||
|
||||
void merge(const EntropyData &rhs)
|
||||
{
|
||||
for (const auto &pair : rhs.map)
|
||||
map[pair.first] += pair.second;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer &buf) const
|
||||
{
|
||||
map.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer &buf)
|
||||
{
|
||||
typename Map::Reader reader(buf);
|
||||
while (reader.next())
|
||||
{
|
||||
const auto &pair = reader.get();
|
||||
map[pair.first] = pair.second;
|
||||
}
|
||||
}
|
||||
|
||||
Float64 get() const
|
||||
{
|
||||
Float64 ShannonEntropy = 0;
|
||||
UInt64 TotalValue = 0;
|
||||
for (const auto & pair : map)
|
||||
{
|
||||
TotalValue += pair.second;
|
||||
}
|
||||
Float64 cur_proba;
|
||||
Float64 log2e = 1 / std::log(2);
|
||||
for (const auto & pair : map)
|
||||
{
|
||||
cur_proba = Float64(pair.second) / TotalValue;
|
||||
ShannonEntropy -= cur_proba * std::log(cur_proba) * log2e;
|
||||
}
|
||||
|
||||
return ShannonEntropy;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Value, typename UserValue>
|
||||
class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>,
|
||||
AggregateFunctionEntropy<Value, UserValue>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionEntropy()
|
||||
{}
|
||||
|
||||
String getName() const override { return "entropy"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
if constexpr (!std::is_same_v<UInt128, Value>)
|
||||
{
|
||||
const auto &column = static_cast<const ColumnVector <UserValue> &>(*columns[0]);
|
||||
this->data(place).add(column.getData()[row_num]);
|
||||
}
|
||||
else
|
||||
{
|
||||
this->data(place).add(UniqVariadicHash<true, false>::apply(1, columns, row_num));
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
|
||||
column.getData().push_back(this->data(place).get());
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -19,7 +19,7 @@ namespace ErrorCodes
|
||||
/** Calculates quantile by collecting all values into array
|
||||
* and applying n-th element (introselect) algorithm for the resulting array.
|
||||
*
|
||||
* It use O(N) memory and it is very inefficient in case of high amount of identical values.
|
||||
* It uses O(N) memory and it is very inefficient in case of high amount of identical values.
|
||||
* But it is very CPU efficient for not large datasets.
|
||||
*/
|
||||
template <typename Value>
|
||||
|
@ -14,7 +14,7 @@ namespace ErrorCodes
|
||||
|
||||
/** Calculates quantile by counting number of occurrences for each value in a hash map.
|
||||
*
|
||||
* It use O(distinct(N)) memory. Can be naturally applied for values with weight.
|
||||
* It uses O(distinct(N)) memory. Can be naturally applied for values with weight.
|
||||
* In case of many identical values, it can be more efficient than QuantileExact even when weight is not used.
|
||||
*/
|
||||
template <typename Value>
|
||||
|
@ -27,6 +27,7 @@ void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionTopK(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
||||
|
||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
||||
@ -65,6 +66,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionsMaxIntersections(factory);
|
||||
registerAggregateFunctionHistogram(factory);
|
||||
registerAggregateFunctionRetention(factory);
|
||||
registerAggregateFunctionEntropy(factory);
|
||||
}
|
||||
|
||||
{
|
||||
|
5
dbms/tests/queries/0_stateless/00902_entropy.reference
Normal file
5
dbms/tests/queries/0_stateless/00902_entropy.reference
Normal file
@ -0,0 +1,5 @@
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
49
dbms/tests/queries/0_stateless/00902_entropy.sql
Normal file
49
dbms/tests/queries/0_stateless/00902_entropy.sql
Normal file
@ -0,0 +1,49 @@
|
||||
CREATE DATABASE IF NOT EXISTS test;
|
||||
DROP TABLE IF EXISTS test.defaults;
|
||||
CREATE TABLE IF NOT EXISTS test.defaults
|
||||
(
|
||||
vals String
|
||||
) ENGINE = Memory;
|
||||
|
||||
insert into test.defaults values ('ba'), ('aa'), ('ba'), ('b'), ('ba'), ('aa');
|
||||
select val < 1.5 and val > 1.459 from (select entropy(vals) as val from test.defaults);
|
||||
|
||||
|
||||
CREATE DATABASE IF NOT EXISTS test;
|
||||
DROP TABLE IF EXISTS test.defaults;
|
||||
CREATE TABLE IF NOT EXISTS test.defaults
|
||||
(
|
||||
vals UInt64
|
||||
) ENGINE = Memory;
|
||||
insert into test.defaults values (0), (0), (1), (0), (0), (0), (1), (2), (3), (5), (3), (1), (1), (4), (5), (2)
|
||||
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);
|
||||
|
||||
|
||||
CREATE DATABASE IF NOT EXISTS test;
|
||||
DROP TABLE IF EXISTS test.defaults;
|
||||
CREATE TABLE IF NOT EXISTS test.defaults
|
||||
(
|
||||
vals UInt32
|
||||
) ENGINE = Memory;
|
||||
insert into test.defaults values (0), (0), (1), (0), (0), (0), (1), (2), (3), (5), (3), (1), (1), (4), (5), (2)
|
||||
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);
|
||||
|
||||
|
||||
CREATE DATABASE IF NOT EXISTS test;
|
||||
DROP TABLE IF EXISTS test.defaults;
|
||||
CREATE TABLE IF NOT EXISTS test.defaults
|
||||
(
|
||||
vals Int32
|
||||
) ENGINE = Memory;
|
||||
insert into test.defaults values (0), (0), (-1), (0), (0), (0), (-1), (2), (3), (5), (3), (-1), (-1), (4), (5), (2)
|
||||
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);
|
||||
|
||||
|
||||
CREATE DATABASE IF NOT EXISTS test;
|
||||
DROP TABLE IF EXISTS test.defaults;
|
||||
CREATE TABLE IF NOT EXISTS test.defaults
|
||||
(
|
||||
vals DateTime
|
||||
) ENGINE = Memory;
|
||||
insert into test.defaults values (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2017-06-15 24:00:00')), (toDateTime('2017-06-15 24:00:00')), (toDateTime('2018-06-15 24:00:00')), (toDateTime('2018-06-15 24:00:00')), (toDateTime('2019-06-15 24:00:00'));
|
||||
select val < 2.189 and val > 2.1886 from (select entropy(vals) as val from test.defaults);
|
Loading…
Reference in New Issue
Block a user