Merging contingency coefficients

This commit is contained in:
Alexey Milovidov 2022-01-02 21:50:41 +03:00
parent 9dc66e1e72
commit 4a094c2efd
9 changed files with 300 additions and 360 deletions

View File

@ -1,61 +1,53 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionCramersV.h>
#include <AggregateFunctions/CrossTab.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include "registerAggregateFunctions.h"
#include <memory>
#include <cmath>
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
namespace DB
{
namespace
{
struct ContingencyData : public AggregateFunctionCramersVData
struct ContingencyData : CrossTabData
{
Float64 get_result() const
static const char * getName()
{
if (cur_size < 2){
throw Exception("Aggregate function contingency coefficient requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
}
return "contingency";
}
Float64 getResult() const
{
if (count < 2)
return std::numeric_limits<Float64>::quiet_NaN();
Float64 phi = 0.0;
for (const auto & cell : pairs) {
UInt128 hash_pair = cell.getKey();
UInt64 count_of_pair_tmp = cell.getMapped();
Float64 count_of_pair = Float64(count_of_pair_tmp);
UInt64 hash1 = (hash_pair << 64 >> 64);
UInt64 hash2 = (hash_pair >> 64);
for (const auto & [key, value_ab] : count_ab)
{
Float64 value_a = count_a.at(key.items[0]);
Float64 value_b = count_b.at(key.items[1]);
UInt64 count1_tmp = n_i.find(hash1)->getMapped();
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
Float64 count1 = static_cast<Float64>(count1_tmp);
Float64 count2 = Float64(count2_tmp);
phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size)
- 2 * count_of_pair + (count1 * count2 / cur_size));
phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count;
}
phi /= cur_size;
return sqrt(phi / (phi + cur_size));
phi /= count;
return sqrt(phi / (phi + count));
}
};
AggregateFunctionPtr createAggregateFunctionContingencyCoefficient(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
void registerAggregateFunctionContingency(AggregateFunctionFactory & factory)
{
assertNoParameters(name, parameters);
return std::make_shared<AggregateFunctionCramersV<ContingencyData>>(argument_types);
factory.registerFunction(ContingencyData::getName(),
[](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
return std::make_shared<AggregateFunctionCrossTab<ContingencyData>>(argument_types);
});
}
}
void registerAggregateFunctionContingencyCoefficient(AggregateFunctionFactory & factory)
{
factory.registerFunction("ContingencyCoefficient", createAggregateFunctionContingencyCoefficient);
}
}

View File

@ -1,27 +1,56 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionCramersV.h>
#include <AggregateFunctions/CrossTab.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include "registerAggregateFunctions.h"
#include <memory>
#include <cmath>
namespace DB
{
namespace
{
AggregateFunctionPtr createAggregateFunctionCramersV(const std::string & name, const DataTypes & argument_types,
const Array & parameters, const Settings *)
struct CramersVData : CrossTabData
{
assertNoParameters(name, parameters);
return std::make_shared<AggregateFunctionCramersV<AggregateFunctionCramersVData>>(argument_types);
}
static const char * getName()
{
return "cramersV";
}
Float64 getResult() const
{
if (count < 2)
return std::numeric_limits<Float64>::quiet_NaN();
Float64 phi = 0.0;
for (const auto & [key, value_ab] : count_ab)
{
Float64 value_a = count_a.at(key.items[0]);
Float64 value_b = count_b.at(key.items[1]);
phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count;
}
phi /= count;
UInt64 q = std::min(count_a.size(), count_b.size());
phi /= q - 1;
return sqrt(phi);
}
};
}
void registerAggregateFunctionCramersV(AggregateFunctionFactory & factory)
{
factory.registerFunction("CramersV", createAggregateFunctionCramersV);
factory.registerFunction(CramersVData::getName(),
[](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
return std::make_shared<AggregateFunctionCrossTab<CramersVData>>(argument_types);
});
}
}

View File

@ -1,207 +0,0 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnTuple.h>
#include <Common/assert_cast.h>
#include <Common/FieldVisitors.h>
#include <Core/Types.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Common/HashTable/HashMap.h>
#include <AggregateFunctions/UniqVariadicHash.h>
#include <limits>
#include <cmath>
#include <type_traits>
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
namespace DB
{
/// State of the CramersV family of aggregate functions: a cross-tabulation of two arguments.
/// Both arguments arrive already hashed to 64 bits, so values with equal hashes share one counter.
struct AggregateFunctionCramersVData
{
    /// Total number of rows passed to add().
    size_t cur_size = 0;

    /// Number of rows for every hash of the first argument.
    HashMap<UInt64, UInt64> n_i;

    /// Number of rows for every hash of the second argument.
    HashMap<UInt64, UInt64> n_j;

    /// Number of rows for every pair of hashes.
    /// The key holds hash1 in its low 64 bits and hash2 in its high 64 bits.
    HashMap<UInt128, UInt64> pairs;

    /// Accounts for one row whose arguments hash to `hash1` and `hash2`.
    void add(UInt64 hash1, UInt64 hash2)
    {
        cur_size += 1;
        n_i[hash1] += 1;
        n_j[hash2] += 1;

        UInt128 hash_pair = hash1 | (static_cast<UInt128>(hash2) << 64);
        pairs[hash_pair] += 1;
    }

    /// Adds every counter of `other` to this state.
    void merge(const AggregateFunctionCramersVData & other)
    {
        cur_size += other.cur_size;

        for (const auto & cell : other.n_i)
            n_i[cell.getKey()] += cell.getMapped();

        for (const auto & cell : other.n_j)
            n_j[cell.getKey()] += cell.getMapped();

        for (const auto & cell : other.pairs)
            pairs[cell.getKey()] += cell.getMapped();
    }

    /// Writes the total count followed by the three hash maps.
    void serialize(WriteBuffer & buf) const
    {
        writeBinary(cur_size, buf);
        n_i.write(buf);
        n_j.write(buf);
        pairs.write(buf);
    }

    /// Reads the state in the same order serialize() writes it.
    void deserialize(ReadBuffer & buf)
    {
        readBinary(cur_size, buf);
        n_i.read(buf);
        n_j.read(buf);
        pairs.read(buf);
    }

    /// Returns sqrt(phi / (q - 1)), where q is the smaller number of distinct hashes among the two arguments.
    /// Throws BAD_ARGUMENTS when fewer than two rows were added.
    /// Returns NaN when one of the arguments has a single category, because the statistic is undefined there.
    Float64 get_result() const
    {
        if (cur_size < 2)
            throw Exception("Aggregate function cramer's v requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);

        /// With q == 1 the final division would be by zero and yield NaN or infinity depending on rounding.
        const UInt64 q = std::min(n_i.size(), n_j.size());
        if (q < 2)
            return std::numeric_limits<Float64>::quiet_NaN();

        Float64 phi = 0.0;
        for (const auto & cell : pairs)
        {
            UInt128 hash_pair = cell.getKey();
            const Float64 count_of_pair = static_cast<Float64>(cell.getMapped());

            /// Unpack the two 64-bit hashes from the combined key.
            UInt64 hash1 = (hash_pair << 64 >> 64);
            UInt64 hash2 = (hash_pair >> 64);

            const Float64 count1 = static_cast<Float64>(n_i.find(hash1)->getMapped());
            const Float64 count2 = static_cast<Float64>(n_j.find(hash2)->getMapped());

            phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size)
                - 2 * count_of_pair + (count1 * count2 / cur_size));
        }

        phi /= cur_size;
        phi /= (q - 1);
        return sqrt(phi);
    }
};
/// Aggregate function of two arguments: hashes each argument to 64 bits, feeds the hashes into `Data`
/// and returns `Data::get_result()` as Float64.
/// `Data` must provide add(UInt64, UInt64), merge, serialize, deserialize and get_result.
/// NOTE(review): getName() reports "CramersV" for every `Data` - confirm this is intended for the other statistics.
template <typename Data>
class AggregateFunctionCramersV : public IAggregateFunctionDataHelper<Data, AggregateFunctionCramersV<Data>>
{
public:
    /// `explicit` keeps a list of data types from converting to an aggregate function implicitly.
    explicit AggregateFunctionCramersV(const DataTypes & arguments)
        : IAggregateFunctionDataHelper<Data, AggregateFunctionCramersV<Data>>({arguments}, {})
    {
    }

    String getName() const override
    {
        return "CramersV";
    }

    bool allocatesMemoryInArena() const override
    {
        return false;
    }

    /// The result is always a Float64 number.
    DataTypePtr getReturnType() const override
    {
        return std::make_shared<DataTypeNumber<Float64>>();
    }

    /// Hashes the first and the second argument of row `row_num` separately and adds the pair to the state.
    void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
    {
        UInt64 hash1 = UniqVariadicHash<false, false>::apply(1, columns, row_num);
        UInt64 hash2 = UniqVariadicHash<false, false>::apply(1, columns + 1, row_num);
        this->data(place).add(hash1, hash2);
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        this->data(place).merge(this->data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf) const override
    {
        this->data(place).serialize(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, Arena *) const override
    {
        this->data(place).deserialize(buf);
    }

    /// Appends the computed statistic to the result column, which is treated as ColumnVector<Float64>.
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        Float64 result = this->data(place).get_result();
        auto & column = static_cast<ColumnVector<Float64> &>(to);
        column.getData().push_back(result);
    }
};
}

View File

@ -0,0 +1,59 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/CrossTab.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include <memory>
#include <cmath>
namespace DB
{
namespace
{
/// Bias-corrected Cramer's V computed from the cross-tab counters inherited from CrossTabData.
struct CramersVBiasCorrectedData : CrossTabData
{
    /// Name under which the function is registered and reported.
    static const char * getName()
    {
        return "cramersVBiasCorrected";
    }

    /// Returns NaN when fewer than two rows were added or when the corrected number of
    /// categories leaves nothing to divide by (a single category, or as many categories as rows).
    Float64 getResult() const
    {
        if (count < 2)
            return std::numeric_limits<Float64>::quiet_NaN();

        Float64 phi = 0.0;
        for (const auto & [key, pair_count] : count_ab)
        {
            /// Convert before squaring: the stored count is UInt64 and its integer square wraps for counts >= 2^32.
            const Float64 value_ab = static_cast<Float64>(pair_count);
            const Float64 value_a = count_a.at(key.items[0]);
            const Float64 value_b = count_b.at(key.items[1]);

            phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count;
        }
        phi /= count;

        const Float64 categories_a = static_cast<Float64>(count_a.size());
        const Float64 categories_b = static_cast<Float64>(count_b.size());
        const Float64 count_minus_one = static_cast<Float64>(count - 1);

        Float64 res = std::max(0.0, phi - (categories_a - 1) * (categories_b - 1) / count_minus_one);
        const Float64 correction_a = categories_a - (categories_a - 1) * (categories_a - 1) / count_minus_one;
        const Float64 correction_b = categories_b - (categories_b - 1) * (categories_b - 1) / count_minus_one;

        /// A non-positive denominator would turn the result into NaN or infinity depending on rounding.
        const Float64 denominator = std::min(correction_a, correction_b) - 1;
        if (!(denominator > 0))
            return std::numeric_limits<Float64>::quiet_NaN();

        res /= denominator;
        return sqrt(res);
    }
};
}
/// Registers the function in `factory` under the name returned by CramersVBiasCorrectedData::getName().
void registerAggregateFunctionCramersVBiasCorrected(AggregateFunctionFactory & factory)
{
    /// Creator callback: passes the parameters to assertNoParameters, then builds the function for the argument types.
    auto creator = [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
    {
        using Function = AggregateFunctionCrossTab<CramersVBiasCorrectedData>;

        assertNoParameters(name, parameters);
        return std::make_shared<Function>(argument_types);
    };

    factory.registerFunction(CramersVBiasCorrectedData::getName(), creator);
}
}

View File

@ -1,67 +0,0 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionCramersV.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include "registerAggregateFunctions.h"
#include <memory>
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
namespace DB
{
namespace
{
/// Bias-corrected Cramer's V computed from the counters inherited from AggregateFunctionCramersVData.
struct BiasCorrectionData : public AggregateFunctionCramersVData
{
    /// Throws BAD_ARGUMENTS when fewer than two rows were added.
    /// Returns NaN when the corrected number of categories leaves nothing to divide by
    /// (a single category, or as many categories as rows).
    Float64 get_result() const
    {
        if (cur_size < 2)
            throw Exception("Aggregate function cramer's v bias corrected requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);

        Float64 phi = 0.0;
        for (const auto & cell : pairs)
        {
            UInt128 hash_pair = cell.getKey();
            const Float64 count_of_pair = static_cast<Float64>(cell.getMapped());

            /// The key holds the first hash in its low 64 bits and the second hash in its high 64 bits.
            UInt64 hash1 = (hash_pair << 64 >> 64);
            UInt64 hash2 = (hash_pair >> 64);

            const Float64 count1 = static_cast<Float64>(n_i.find(hash1)->getMapped());
            const Float64 count2 = static_cast<Float64>(n_j.find(hash2)->getMapped());

            phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size)
                - 2 * count_of_pair + (count1 * count2 / cur_size));
        }
        phi /= cur_size;

        const Float64 categories_i = static_cast<Float64>(n_i.size());
        const Float64 categories_j = static_cast<Float64>(n_j.size());
        const Float64 size_minus_one = static_cast<Float64>(cur_size - 1);

        Float64 answ = std::max(0.0, phi - ((categories_i - 1) * (categories_j - 1) / size_minus_one));
        const Float64 k = categories_i - (categories_i - 1) * (categories_i - 1) / size_minus_one;
        const Float64 r = categories_j - (categories_j - 1) * (categories_j - 1) / size_minus_one;

        /// A non-positive denominator would turn the result into NaN or infinity depending on rounding.
        const Float64 denominator = std::min(k, r) - 1;
        if (!(denominator > 0))
            return std::numeric_limits<Float64>::quiet_NaN();

        answ /= denominator;
        return sqrt(answ);
    }
};
/// Factory callback: passes the parameters to assertNoParameters,
/// then builds AggregateFunctionCramersV over BiasCorrectionData for the given argument types.
AggregateFunctionPtr createAggregateFunctionCramersVBiasCorrection(
    const std::string & name,
    const DataTypes & argument_types,
    const Array & parameters,
    const Settings *)
{
    using Function = AggregateFunctionCramersV<BiasCorrectionData>;

    assertNoParameters(name, parameters);
    return std::make_shared<Function>(argument_types);
}
}
/// Registers createAggregateFunctionCramersVBiasCorrection in `factory` under the name "CramersVBiasCorrection".
void registerAggregateFunctionCramersVBiasCorrection(AggregateFunctionFactory & factory)
{
    auto * creator = &createAggregateFunctionCramersVBiasCorrection;
    factory.registerFunction("CramersVBiasCorrection", creator);
}
}

View File

@ -1,48 +1,43 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionCramersV.h>
#include <AggregateFunctions/CrossTab.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include "registerAggregateFunctions.h"
#include <memory>
#include <cmath>
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
namespace DB
{
namespace
{
struct TheilsUData : public AggregateFunctionCramersVData
struct TheilsUData : CrossTabData
{
Float64 get_result() const
static const char * getName()
{
if (cur_size < 2){
throw Exception("Aggregate function theil's u requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
}
Float64 h_x = 0.0;
for (const auto & cell : n_i) {
UInt64 count_x_tmp = cell.getMapped();
Float64 count_x = Float64(count_x_tmp);
h_x += (count_x / cur_size) * (log(count_x / cur_size));
}
return "theilsU";
}
Float64 getResult() const
{
if (count < 2)
return std::numeric_limits<Float64>::quiet_NaN();
Float64 h_a = 0.0;
for (const auto & [key, value] : count_a)
{
Float64 value_float = value;
h_a += (value_float / count) * log(value_float / count);
}
Float64 dep = 0.0;
for (const auto & cell : pairs) {
UInt128 hash_pair = cell.getKey();
UInt64 count_of_pair_tmp = cell.getMapped();
Float64 count_of_pair = Float64(count_of_pair_tmp);
for (const auto & [key, value] : count_ab)
{
Float64 value_ab = value;
Float64 value_b = count_b.at(key.items[1]);
UInt64 hash2 = (hash_pair >> 64);
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
Float64 count2 = Float64 (count2_tmp);
dep += (count_of_pair / cur_size) * log(count_of_pair / count2);
dep += (value_ab / count) * log(value_ab / value_b);
}
dep -= h_x;
@ -51,18 +46,16 @@ namespace
}
};
AggregateFunctionPtr createAggregateFunctionTheilsU(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
return std::make_shared<AggregateFunctionCramersV<TheilsUData>>(argument_types);
}
}
void registerAggregateFunctionTheilsU(AggregateFunctionFactory & factory)
{
factory.registerFunction("TheilsU", createAggregateFunctionTheilsU);
factory.registerFunction(TheilsUData::getName(),
[](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertNoParameters(name, parameters);
return std::make_shared<AggregateFunctionCrossTab<TheilsUData>>(argument_types);
});
}
}
}

View File

@ -0,0 +1,134 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <Common/assert_cast.h>
#include <DataTypes/DataTypesNumber.h>
#include <Common/HashTable/HashMap.h>
#include <AggregateFunctions/UniqVariadicHash.h>
/** Aggregate functions that calculate statistics on top of a cross-tab (contingency table):
* the histogram of every argument and of every pair of argument values.
* These statistics include:
* - Cramer's V;
* - Theil's U;
* - contingency coefficient.
* Each of them can be interpreted as an interdependency coefficient between the arguments,
* or as a non-parametric correlation coefficient.
*/
namespace DB
{
struct CrossTabData
{
/// Total count.
UInt64 count = 0;
/// Count of every value of the first and second argument (values are pre-hashed).
/// Note: non-cryptographic 64bit hash is used, it means that the calculation is approximate.
HashMapWithStackMemory<UInt64, UInt64, TrivialHash, 4> count_a;
HashMapWithStackMemory<UInt64, UInt64, TrivialHash, 4> count_b;
/// Count of every pair of values. We pack two hashes into UInt128.
HashMapWithStackMemory<UInt128, UInt64, UInt128Hash, 4> count_ab;
void add(UInt64 hash1, UInt64 hash2)
{
++count;
++count_a[hash1];
++count_b[hash2];
UInt128 hash_pair{hash1, hash2};
++count_ab[hash_pair];
}
void merge(const CrossTabData & other)
{
count += other.count;
for (const auto & [key, value] : other.count_a)
count_a[key] += value;
for (const auto & [key, value] : other.count_b)
count_b[key] += value;
for (const auto & [key, value] : other.count_ab)
count_ab[key] += value;
}
void serialize(WriteBuffer &buf) const
{
writeBinary(count, buf);
count_a.write(buf);
count_b.write(buf);
count_ab.write(buf);
}
void deserialize(ReadBuffer &buf)
{
readBinary(count, buf);
count_a.read(buf);
count_b.read(buf);
count_ab.read(buf);
}
};
/// Aggregate function of two arguments: hashes each argument to 64 bits, feeds the hashes into `Data`
/// and returns `Data::getResult()` as Float64.
/// `Data` must provide a static getName() plus add(UInt64, UInt64), merge, serialize, deserialize and getResult.
template <typename Data>
class AggregateFunctionCrossTab : public IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>
{
public:
    /// `explicit` keeps a list of data types from converting to an aggregate function implicitly.
    explicit AggregateFunctionCrossTab(const DataTypes & arguments)
        : IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>({arguments}, {})
    {
    }

    /// The name is taken from the statistic itself.
    String getName() const override
    {
        return Data::getName();
    }

    bool allocatesMemoryInArena() const override
    {
        return false;
    }

    /// The result is always a Float64 number.
    DataTypePtr getReturnType() const override
    {
        return std::make_shared<DataTypeNumber<Float64>>();
    }

    /// Hashes the first and the second argument of row `row_num` separately and adds the pair to the state.
    void add(
        AggregateDataPtr __restrict place,
        const IColumn ** columns,
        size_t row_num,
        Arena *) const override
    {
        UInt64 hash1 = UniqVariadicHash<false, false>::apply(1, &columns[0], row_num);
        UInt64 hash2 = UniqVariadicHash<false, false>::apply(1, &columns[1], row_num);

        this->data(place).add(hash1, hash2);
    }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
    {
        this->data(place).merge(this->data(rhs));
    }

    void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf) const override
    {
        this->data(place).serialize(buf);
    }

    void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, Arena *) const override
    {
        this->data(place).deserialize(buf);
    }

    /// Appends the computed statistic to the result column, which is treated as ColumnVector<Float64>.
    void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
    {
        Float64 result = this->data(place).getResult();
        auto & column = static_cast<ColumnVector<Float64> &>(to);
        column.getData().push_back(result);
    }
};
}

View File

@ -50,8 +50,8 @@ void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &);
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &);
void registerAggregateFunctionCramersV(AggregateFunctionFactory &);
void registerAggregateFunctionTheilsU(AggregateFunctionFactory &);
void registerAggregateFunctionContingencyCoefficient(AggregateFunctionFactory &);
void registerAggregateFunctionCramersVBiasCorrection(AggregateFunctionFactory &);
void registerAggregateFunctionContingency(AggregateFunctionFactory &);
void registerAggregateFunctionCramersVBiasCorrected(AggregateFunctionFactory &);
void registerAggregateFunctionSingleValueOrNull(AggregateFunctionFactory &);
void registerAggregateFunctionSequenceNextNode(AggregateFunctionFactory &);
void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &);
@ -105,8 +105,8 @@ void registerAggregateFunctions()
registerAggregateFunctionsBitwise(factory);
registerAggregateFunctionCramersV(factory);
registerAggregateFunctionTheilsU(factory);
registerAggregateFunctionContingencyCoefficient(factory);
registerAggregateFunctionCramersVBiasCorrection(factory);
registerAggregateFunctionContingency(factory);
registerAggregateFunctionCramersVBiasCorrected(factory);
registerAggregateFunctionsBitmap(factory);
registerAggregateFunctionsMaxIntersections(factory);
registerAggregateFunctionHistogram(factory);

View File

@ -262,6 +262,13 @@ public:
return it->getMapped();
}
/// Returns a reference to the value mapped to key `x`.
/// A missing key is treated as a programming error: DB::Exception with LOGICAL_ERROR is thrown.
/// NOTE(review): the method is const but returns a non-const reference - confirm that find() on a
/// const map yields a mutable mapped value, or make the return type const.
typename Cell::Mapped & ALWAYS_INLINE at(const Key & x) const
{
/// Single lookup; the found entry is reused for the return value.
if (auto it = this->find(x); it != this->end())
return it->getMapped();
throw DB::Exception("Cannot find element in HashMap::at method", DB::ErrorCodes::LOGICAL_ERROR);
}
};
namespace std