Merge pull request #4238 from Quid37/yandex_open_code_competition

Implement Shannon entropy aggregate function
This commit is contained in:
alexey-milovidov 2019-02-04 18:43:29 +03:00 committed by GitHub
commit 2216250a84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 268 additions and 2 deletions

View File

@ -0,0 +1,58 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionEntropy.h>
#include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
namespace
{
AggregateFunctionPtr createAggregateFunctionEntropy(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertNoParameters(name, parameters);
if (argument_types.empty())
throw Exception("Incorrect number of arguments for aggregate function " + name,
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
WhichDataType which(argument_types[0]);
if (isNumber(argument_types[0]))
{
if (which.isUInt64())
{
return std::make_shared<AggregateFunctionEntropy<UInt64>>();
}
else if (which.isInt64())
{
return std::make_shared<AggregateFunctionEntropy<Int64>>();
}
else if (which.isInt32())
{
return std::make_shared<AggregateFunctionEntropy<Int32>>();
}
else if (which.isUInt32())
{
return std::make_shared<AggregateFunctionEntropy<UInt32>>();
}
else if (which.isUInt128())
{
return std::make_shared<AggregateFunctionEntropy<UInt128, true>>();
}
}
return std::make_shared<AggregateFunctionEntropy<UInt128>>();
}
}
void registerAggregateFunctionEntropy(AggregateFunctionFactory & factory)
{
factory.registerFunction("entropy", createAggregateFunctionEntropy);
}
}

View File

@ -0,0 +1,152 @@
#pragma once
#include <AggregateFunctions/FactoryHelpers.h>
#include <Common/HashTable/HashMap.h>
#include <Common/NaNUtils.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/UniqVariadicHash.h>
#include <Columns/ColumnArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <cmath>
namespace DB
{
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function
*/
template <typename Value, bool is_hashed>
struct EntropyData
{
using Weight = UInt64;
using HashingMap = HashMap <
Value, Weight,
HashCRC32<Value>,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>
>;
using TrivialMap = HashMap <
Value, Weight,
UInt128TrivialHash,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>
>;
/// If column value is UInt128 then there is no need to hash values
using Map = std::conditional_t<is_hashed, TrivialMap, HashingMap>;
Map map;
void add(const Value & x)
{
if (!isNaN(x))
++map[x];
}
void add(const Value & x, const Weight & weight)
{
if (!isNaN(x))
map[x] += weight;
}
void merge(const EntropyData & rhs)
{
for (const auto & pair : rhs.map)
map[pair.first] += pair.second;
}
void serialize(WriteBuffer & buf) const
{
map.write(buf);
}
void deserialize(ReadBuffer & buf)
{
typename Map::Reader reader(buf);
while (reader.next())
{
const auto &pair = reader.get();
map[pair.first] = pair.second;
}
}
Float64 get() const
{
Float64 shannon_entropy = 0;
UInt64 total_value = 0;
for (const auto & pair : map)
{
total_value += pair.second;
}
Float64 cur_proba;
Float64 log2e = 1 / std::log(2);
for (const auto & pair : map)
{
cur_proba = Float64(pair.second) / total_value;
shannon_entropy -= cur_proba * std::log(cur_proba) * log2e;
}
return shannon_entropy;
}
};
template <typename Value, bool is_hashed = false>
class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value, is_hashed>,
AggregateFunctionEntropy<Value>>
{
public:
AggregateFunctionEntropy()
{}
String getName() const override { return "entropy"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNumber<Float64>>();
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
{
if constexpr (!std::is_same_v<UInt128, Value>)
{
/// Here we manage only with numerical types
const auto &column = static_cast<const ColumnVector <Value> &>(*columns[0]);
this->data(place).add(column.getData()[row_num]);
}
else
{
this->data(place).add(UniqVariadicHash<true, false>::apply(1, columns, row_num));
}
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(const_cast<AggregateDataPtr>(place)).serialize(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
this->data(place).deserialize(buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
column.getData().push_back(this->data(place).get());
}
const char * getHeaderFilePath() const override { return __FILE__; }
};
}

View File

@ -19,7 +19,7 @@ namespace ErrorCodes
/** Calculates quantile by collecting all values into array
* and applying n-th element (introselect) algorithm for the resulting array.
*
* It use O(N) memory and it is very inefficient in case of high amount of identical values.
* It uses O(N) memory and it is very inefficient in case of high amount of identical values.
* But it is very CPU efficient for not large datasets.
*/
template <typename Value>

View File

@ -14,7 +14,7 @@ namespace ErrorCodes
/** Calculates quantile by counting number of occurrences for each value in a hash map.
*
* It use O(distinct(N)) memory. Can be naturally applied for values with weight.
* It uses O(distinct(N)) memory. Can be naturally applied for values with weight.
* In case of many identical values, it can be more efficient than QuantileExact even when weight is not used.
*/
template <typename Value>

View File

@ -27,6 +27,7 @@ void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory &);
void registerAggregateFunctionTopK(AggregateFunctionFactory &);
void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -65,6 +66,7 @@ void registerAggregateFunctions()
registerAggregateFunctionsMaxIntersections(factory);
registerAggregateFunctionHistogram(factory);
registerAggregateFunctionRetention(factory);
registerAggregateFunctionEntropy(factory);
}
{

View File

@ -0,0 +1,5 @@
1
1
1
1
1

View File

@ -0,0 +1,49 @@
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals String
) ENGINE = Memory;
insert into test.defaults values ('ba'), ('aa'), ('ba'), ('b'), ('ba'), ('aa');
select val < 1.5 and val > 1.459 from (select entropy(vals) as val from test.defaults);
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals UInt64
) ENGINE = Memory;
insert into test.defaults values (0), (0), (1), (0), (0), (0), (1), (2), (3), (5), (3), (1), (1), (4), (5), (2)
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals UInt32
) ENGINE = Memory;
insert into test.defaults values (0), (0), (1), (0), (0), (0), (1), (2), (3), (5), (3), (1), (1), (4), (5), (2)
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals Int32
) ENGINE = Memory;
insert into test.defaults values (0), (0), (-1), (0), (0), (0), (-1), (2), (3), (5), (3), (-1), (-1), (4), (5), (2)
select val < 2.4 and val > 2.3393 from (select entropy(vals) as val from test.defaults);
CREATE DATABASE IF NOT EXISTS test;
DROP TABLE IF EXISTS test.defaults;
CREATE TABLE IF NOT EXISTS test.defaults
(
vals DateTime
) ENGINE = Memory;
insert into test.defaults values (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 23:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2016-06-15 24:00:00')), (toDateTime('2017-06-15 24:00:00')), (toDateTime('2017-06-15 24:00:00')), (toDateTime('2018-06-15 24:00:00')), (toDateTime('2018-06-15 24:00:00')), (toDateTime('2019-06-15 24:00:00'));
select val < 2.189 and val > 2.1886 from (select entropy(vals) as val from test.defaults);