Merge pull request #4321 from yandex/entropy-rework

Fixed entropy aggregate function
This commit is contained in:
alexey-milovidov 2019-02-10 01:26:45 +03:00 committed by GitHub
commit bb6d70cae6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 103 additions and 72 deletions

View File

@ -1,6 +1,8 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionEntropy.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
namespace DB
{
@ -20,32 +22,16 @@ AggregateFunctionPtr createAggregateFunctionEntropy(const std::string & name, co
throw Exception("Incorrect number of arguments for aggregate function " + name,
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
WhichDataType which(argument_types[0]);
if (isNumber(argument_types[0]))
size_t num_args = argument_types.size();
if (num_args == 1)
{
if (which.isUInt64())
{
return std::make_shared<AggregateFunctionEntropy<UInt64>>();
}
else if (which.isInt64())
{
return std::make_shared<AggregateFunctionEntropy<Int64>>();
}
else if (which.isInt32())
{
return std::make_shared<AggregateFunctionEntropy<Int32>>();
}
else if (which.isUInt32())
{
return std::make_shared<AggregateFunctionEntropy<UInt32>>();
}
else if (which.isUInt128())
{
return std::make_shared<AggregateFunctionEntropy<UInt128, true>>();
}
/// Specialized implementation for single argument of numeric type.
if (auto res = createWithNumericBasedType<AggregateFunctionEntropy>(*argument_types[0], num_args))
return AggregateFunctionPtr(res);
}
return std::make_shared<AggregateFunctionEntropy<UInt128>>();
/// Generic implementation for other types or for multiple arguments.
return std::make_shared<AggregateFunctionEntropy<UInt128>>(num_args);
}
}

View File

@ -1,43 +1,41 @@
#pragma once
#include <AggregateFunctions/FactoryHelpers.h>
#include <Common/HashTable/HashMap.h>
#include <Common/NaNUtils.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/UniqVariadicHash.h>
#include <Columns/ColumnArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Columns/ColumnVector.h>
#include <cmath>
namespace DB
{
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function.
* Entropy is measured in bits (base-2 logarithm is used).
*/
template <typename Value, bool is_hashed>
template <typename Value>
struct EntropyData
{
using Weight = UInt64;
using HashingMap = HashMap <
Value, Weight,
HashCRC32<Value>,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>
>;
using TrivialMap = HashMap <
Value, Weight,
UInt128TrivialHash,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>
>;
using HashingMap = HashMap<
Value, Weight,
HashCRC32<Value>,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>>;
/// If column value is UInt128 then there is no need to hash values
using Map = std::conditional_t<is_hashed, TrivialMap, HashingMap>;
/// For the case of pre-hashed values.
using TrivialMap = HashMap<
Value, Weight,
UInt128TrivialHash,
HashTableGrower<4>,
HashTableAllocatorWithStackMemory<sizeof(std::pair<Value, Weight>) * (1 << 3)>>;
using Map = std::conditional_t<std::is_same_v<UInt128, Value>, TrivialMap, HashingMap>;
Map map;
@ -69,38 +67,39 @@ struct EntropyData
typename Map::Reader reader(buf);
while (reader.next())
{
const auto &pair = reader.get();
const auto & pair = reader.get();
map[pair.first] = pair.second;
}
}
Float64 get() const
{
Float64 shannon_entropy = 0;
UInt64 total_value = 0;
for (const auto & pair : map)
{
total_value += pair.second;
}
Float64 cur_proba;
Float64 log2e = 1 / std::log(2);
Float64 shannon_entropy = 0;
for (const auto & pair : map)
{
cur_proba = Float64(pair.second) / total_value;
shannon_entropy -= cur_proba * std::log(cur_proba) * log2e;
Float64 frequency = Float64(pair.second) / total_value;
shannon_entropy -= frequency * log2(frequency);
}
return shannon_entropy;
}
};
template <typename Value, bool is_hashed = false>
class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value, is_hashed>,
AggregateFunctionEntropy<Value>>
template <typename Value>
class AggregateFunctionEntropy final : public IAggregateFunctionDataHelper<EntropyData<Value>, AggregateFunctionEntropy<Value>>
{
private:
size_t num_args;
public:
AggregateFunctionEntropy()
{}
AggregateFunctionEntropy(size_t num_args) : num_args(num_args)
{
}
String getName() const override { return "entropy"; }
@ -114,13 +113,12 @@ public:
if constexpr (!std::is_same_v<UInt128, Value>)
{
/// Here we manage only with numerical types
const auto &column = static_cast<const ColumnVector <Value> &>(*columns[0]);
const auto & column = static_cast<const ColumnVector <Value> &>(*columns[0]);
this->data(place).add(column.getData()[row_num]);
}
else
{
this->data(place).add(UniqVariadicHash<true, false>::apply(1, columns, row_num));
this->data(place).add(UniqVariadicHash<true, false>::apply(num_args, columns, row_num));
}
}
@ -141,12 +139,11 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
auto & column = static_cast<ColumnVector<Float64> &>(to);
column.getData().push_back(this->data(place).get());
}
const char * getHeaderFilePath() const override { return __FILE__; }
};
}

View File

@ -41,7 +41,7 @@ template <typename T> using FuncQuantilesTDigestWeighted = AggregateFunctionQuan
template <template <typename> class Function>
static constexpr bool SupportDecimal()
static constexpr bool supportDecimal()
{
return std::is_same_v<Function<Float32>, FuncQuantileExact<Float32>> ||
std::is_same_v<Function<Float32>, FuncQuantilesExact<Float32>>;
@ -61,11 +61,10 @@ AggregateFunctionPtr createAggregateFunctionQuantile(const std::string & name, c
if (which.idx == TypeIndex::TYPE) return std::make_shared<Function<TYPE>>(argument_type, params);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
#undef FOR_NUMERIC_TYPES
if (which.idx == TypeIndex::Date) return std::make_shared<Function<DataTypeDate::FieldType>>(argument_type, params);
if (which.idx == TypeIndex::DateTime) return std::make_shared<Function<DataTypeDateTime::FieldType>>(argument_type, params);
if constexpr (SupportDecimal<Function>())
if constexpr (supportDecimal<Function>())
{
if (which.idx == TypeIndex::Decimal32) return std::make_shared<Function<Decimal32>>(argument_type, params);
if (which.idx == TypeIndex::Decimal64) return std::make_shared<Function<Decimal64>>(argument_type, params);

View File

@ -20,7 +20,7 @@ namespace DB
/** Create an aggregate function with a numeric type in the template parameter, depending on the type of the argument.
*/
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithNumericType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
@ -33,7 +33,7 @@ static IAggregateFunction * createWithNumericType(const IDataType & argument_typ
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename ... TArgs>
template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename... TArgs>
static IAggregateFunction * createWithNumericType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
@ -46,7 +46,7 @@ static IAggregateFunction * createWithNumericType(const IDataType & argument_typ
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, template <typename> class Data, typename ... TArgs>
template <template <typename, typename> class AggregateFunctionTemplate, template <typename> class Data, typename... TArgs>
static IAggregateFunction * createWithNumericType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
@ -59,7 +59,7 @@ static IAggregateFunction * createWithNumericType(const IDataType & argument_typ
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, template <typename> class Data, typename ... TArgs>
template <template <typename, typename> class AggregateFunctionTemplate, template <typename> class Data, typename... TArgs>
static IAggregateFunction * createWithUnsignedIntegerType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
@ -70,7 +70,7 @@ static IAggregateFunction * createWithUnsignedIntegerType(const IDataType & argu
return nullptr;
}
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithNumericBasedType(const IDataType & argument_type, TArgs && ... args)
{
IAggregateFunction * f = createWithNumericType<AggregateFunctionTemplate>(argument_type, std::forward<TArgs>(args)...);
@ -85,7 +85,7 @@ static IAggregateFunction * createWithNumericBasedType(const IDataType & argumen
return nullptr;
}
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithDecimalType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
@ -98,7 +98,7 @@ static IAggregateFunction * createWithDecimalType(const IDataType & argument_typ
/** For template with two arguments.
*/
template <typename FirstType, template <typename, typename> class AggregateFunctionTemplate, typename ... TArgs>
template <typename FirstType, template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithTwoNumericTypesSecond(const IDataType & second_type, TArgs && ... args)
{
WhichDataType which(second_type);
@ -111,7 +111,7 @@ static IAggregateFunction * createWithTwoNumericTypesSecond(const IDataType & se
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, typename ... TArgs>
template <template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithTwoNumericTypes(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
{
WhichDataType which(first_type);

View File

@ -0,0 +1,41 @@
<test>
<name>entropy</name>
<type>loop</type>
<preconditions>
<table_exists>test.hits</table_exists>
</preconditions>
<stop_conditions>
<all_of>
<total_time_ms>10000</total_time_ms>
</all_of>
<any_of>
<average_speed_not_changing_for_ms>5000</average_speed_not_changing_for_ms>
<total_time_ms>20000</total_time_ms>
</any_of>
</stop_conditions>
<main_metric>
<min_time/>
</main_metric>
<substitutions>
<substitution>
<name>args</name>
<values>
<value>SearchEngineID</value>
<value>SearchPhrase</value>
<value>MobilePhoneModel</value>
<value>URL</value>
<value>URLDomain</value>
<value>URL, URLDomain</value>
<value>ClientIP</value>
<value>RegionID</value>
<value>ClientIP, RegionID</value>
</values>
</substitution>
</substitutions>
<query>SELECT entropy({args}) FROM test.hits</query>
</test>

View File

@ -0,0 +1,2 @@
8
1

View File

@ -0,0 +1,2 @@
SELECT round(entropy(number), 6) FROM remote('127.0.0.{1,2}', numbers(256));
SELECT entropy(rand64()) > 8 FROM remote('127.0.0.{1,2}', numbers(256));

View File

@ -0,0 +1,2 @@
1
1 1

View File

@ -0,0 +1,2 @@
SELECT max(x) - min(x) < 0.000001 FROM (WITH entropy(number % 2, number % 5) AS e1, log2(10) AS e2, log2(uniq(number % 2, number % 5)) AS e3, entropy(number) AS e4, entropy(toString(number)) AS e5, entropy(number % 2 ? 'hello' : 'world', range(number % 5)) AS e6, entropy(number, number + 1, number - 1) AS e7, entropy(([[number], [number, number]], [[], [number]])) AS e8 SELECT arrayJoin([e1, e2, e3, e4, e5, e6, e7, e8]) AS x FROM numbers(10));
SELECT abs(entropy(number) - 8) < 0.000001, abs(entropy(number % 64, number % 32) - 6) < 0.000001 FROM numbers(256);