mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-28 20:50:49 +00:00
Merging contingency coefficients
This commit is contained in:
parent
9dc66e1e72
commit
4a094c2efd
@ -1,61 +1,53 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionCramersV.h>
|
||||
#include <AggregateFunctions/CrossTab.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include "registerAggregateFunctions.h"
|
||||
#include <memory>
|
||||
#include <cmath>
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
|
||||
struct ContingencyData : public AggregateFunctionCramersVData
|
||||
struct ContingencyData : CrossTabData
|
||||
{
|
||||
Float64 get_result() const
|
||||
static const char * getName()
|
||||
{
|
||||
if (cur_size < 2){
|
||||
throw Exception("Aggregate function contingency coefficient requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
|
||||
return "contingency";
|
||||
}
|
||||
|
||||
Float64 getResult() const
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::quiet_NaN();
|
||||
|
||||
Float64 phi = 0.0;
|
||||
for (const auto & cell : pairs) {
|
||||
UInt128 hash_pair = cell.getKey();
|
||||
UInt64 count_of_pair_tmp = cell.getMapped();
|
||||
Float64 count_of_pair = Float64(count_of_pair_tmp);
|
||||
UInt64 hash1 = (hash_pair << 64 >> 64);
|
||||
UInt64 hash2 = (hash_pair >> 64);
|
||||
for (const auto & [key, value_ab] : count_ab)
|
||||
{
|
||||
Float64 value_a = count_a.at(key.items[0]);
|
||||
Float64 value_b = count_b.at(key.items[1]);
|
||||
|
||||
UInt64 count1_tmp = n_i.find(hash1)->getMapped();
|
||||
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
|
||||
Float64 count1 = static_cast<Float64>(count1_tmp);
|
||||
Float64 count2 = Float64(count2_tmp);
|
||||
|
||||
phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size)
|
||||
- 2 * count_of_pair + (count1 * count2 / cur_size));
|
||||
phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count;
|
||||
}
|
||||
phi /= cur_size;
|
||||
return sqrt(phi / (phi + cur_size));
|
||||
phi /= count;
|
||||
|
||||
return sqrt(phi / (phi + count));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionContingencyCoefficient(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
void registerAggregateFunctionContingency(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction(ContingencyData::getName(),
|
||||
[](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
return std::make_shared<AggregateFunctionCramersV<ContingencyData>>(argument_types);
|
||||
return std::make_shared<AggregateFunctionCrossTab<ContingencyData>>(argument_types);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionContingencyCoefficient(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("ContingencyCoefficient", createAggregateFunctionContingencyCoefficient);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,27 +1,56 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionCramersV.h>
|
||||
#include <AggregateFunctions/CrossTab.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include "registerAggregateFunctions.h"
|
||||
#include <memory>
|
||||
#include <cmath>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionCramersV(const std::string & name, const DataTypes & argument_types,
|
||||
const Array & parameters, const Settings *)
|
||||
struct CramersVData : CrossTabData
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
return std::make_shared<AggregateFunctionCramersV<AggregateFunctionCramersVData>>(argument_types);
|
||||
static const char * getName()
|
||||
{
|
||||
return "cramersV";
|
||||
}
|
||||
|
||||
Float64 getResult() const
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::quiet_NaN();
|
||||
|
||||
Float64 phi = 0.0;
|
||||
for (const auto & [key, value_ab] : count_ab)
|
||||
{
|
||||
Float64 value_a = count_a.at(key.items[0]);
|
||||
Float64 value_b = count_b.at(key.items[1]);
|
||||
|
||||
phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count;
|
||||
}
|
||||
|
||||
phi /= count;
|
||||
UInt64 q = std::min(count_a.size(), count_b.size());
|
||||
phi /= q - 1;
|
||||
|
||||
return sqrt(phi);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionCramersV(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("CramersV", createAggregateFunctionCramersV);
|
||||
factory.registerFunction(CramersVData::getName(),
|
||||
[](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
return std::make_shared<AggregateFunctionCrossTab<CramersVData>>(argument_types);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,207 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Core/Types.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Common/HashTable/HashMap.h>
|
||||
#include <AggregateFunctions/UniqVariadicHash.h>
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
|
||||
struct AggregateFunctionCramersVData
|
||||
{
|
||||
size_t cur_size = 0;
|
||||
HashMap<UInt64, UInt64> n_i;
|
||||
HashMap<UInt64, UInt64> n_j;
|
||||
HashMap<UInt128, UInt64> pairs;
|
||||
|
||||
|
||||
void add(UInt64 hash1, UInt64 hash2)
|
||||
{
|
||||
cur_size += 1;
|
||||
n_i[hash1] += 1;
|
||||
n_j[hash2] += 1;
|
||||
|
||||
UInt128 hash_pair = hash1 | (static_cast<UInt128>(hash2) << 64);
|
||||
pairs[hash_pair] += 1;
|
||||
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionCramersVData &other)
|
||||
{
|
||||
cur_size += other.cur_size;
|
||||
|
||||
for (const auto& pair : other.n_i) {
|
||||
UInt64 hash1 = pair.getKey();
|
||||
UInt64 count = pair.getMapped();
|
||||
n_i[hash1] += count;
|
||||
}
|
||||
for (const auto& pair : other.n_j) {
|
||||
UInt64 hash1 = pair.getKey();
|
||||
UInt64 count = pair.getMapped();
|
||||
n_j[hash1] += count;
|
||||
}
|
||||
for (const auto& pair : other.pairs) {
|
||||
UInt128 hash1 = pair.getKey();
|
||||
UInt64 count = pair.getMapped();
|
||||
pairs[hash1] += count;
|
||||
}
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer &buf) const
|
||||
{
|
||||
writeBinary(cur_size, buf);
|
||||
n_i.write(buf);
|
||||
n_j.write(buf);
|
||||
pairs.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer &buf)
|
||||
{
|
||||
readBinary(cur_size, buf);
|
||||
n_i.read(buf);
|
||||
n_j.read(buf);
|
||||
pairs.read(buf);
|
||||
}
|
||||
|
||||
Float64 get_result() const
|
||||
{
|
||||
if (cur_size < 2){
|
||||
throw Exception("Aggregate function cramer's v requires et least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
Float64 phi = 0.0;
|
||||
for (const auto & cell : pairs) {
|
||||
UInt128 hash_pair = cell.getKey();
|
||||
UInt64 count_of_pair_tmp = cell.getMapped();
|
||||
Float64 count_of_pair = Float64(count_of_pair_tmp);
|
||||
UInt64 hash1 = (hash_pair << 64 >> 64);
|
||||
UInt64 hash2 = (hash_pair >> 64);
|
||||
|
||||
UInt64 count1_tmp = n_i.find(hash1)->getMapped();
|
||||
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
|
||||
Float64 count1 = static_cast<Float64>(count1_tmp);
|
||||
Float64 count2 = Float64(count2_tmp);
|
||||
|
||||
phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size)
|
||||
- 2 * count_of_pair + (count1 * count2 / cur_size));
|
||||
}
|
||||
phi /= cur_size;
|
||||
|
||||
UInt64 q = std::min(n_i.size(), n_j.size());
|
||||
phi /= (q - 1);
|
||||
return sqrt(phi);
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Data>
|
||||
|
||||
class AggregateFunctionCramersV : public
|
||||
IAggregateFunctionDataHelper<
|
||||
Data,
|
||||
AggregateFunctionCramersV<Data>
|
||||
>
|
||||
{
|
||||
|
||||
public:
|
||||
AggregateFunctionCramersV(
|
||||
const DataTypes & arguments
|
||||
):
|
||||
IAggregateFunctionDataHelper<
|
||||
Data,
|
||||
AggregateFunctionCramersV<Data>
|
||||
> ({arguments}, {})
|
||||
{
|
||||
// notice: arguments has been in factory
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "CramersV";
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return false; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
||||
void add(
|
||||
AggregateDataPtr __restrict place,
|
||||
const IColumn ** columns,
|
||||
size_t row_num,
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
UInt64 hash1 = UniqVariadicHash<false, false>::apply(1, columns, row_num);
|
||||
UInt64 hash2 = UniqVariadicHash<false, false>::apply(1, columns + 1, row_num);
|
||||
|
||||
this->data(place).add(hash1, hash2);
|
||||
}
|
||||
|
||||
void merge(
|
||||
AggregateDataPtr __restrict place,
|
||||
ConstAggregateDataPtr rhs, Arena *
|
||||
) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(
|
||||
ConstAggregateDataPtr __restrict place,
|
||||
WriteBuffer & buf
|
||||
) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(
|
||||
AggregateDataPtr __restrict place,
|
||||
ReadBuffer & buf, Arena *
|
||||
) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(
|
||||
AggregateDataPtr __restrict place,
|
||||
IColumn & to,
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
Float64 result = this->data(place).get_result();
|
||||
// std::cerr << "cur_size" << this->data(place).cur_size << '\n';
|
||||
// std::cerr << "n_i size" << this->data(place).n_i.size() << '\n';
|
||||
// std::cerr << "n_j size" << this->data(place).n_j.size() << '\n';
|
||||
// std::cerr << "pair size " << this->data(place).pairs.size() << '\n';
|
||||
// std::cerr << "result " << result << '\n';
|
||||
|
||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
||||
column.getData().push_back(result);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -0,0 +1,59 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/CrossTab.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <memory>
|
||||
#include <cmath>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
struct CramersVBiasCorrectedData : CrossTabData
|
||||
{
|
||||
static const char * getName()
|
||||
{
|
||||
return "cramersVBiasCorrected";
|
||||
}
|
||||
|
||||
Float64 getResult() const
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::quiet_NaN();
|
||||
|
||||
Float64 phi = 0.0;
|
||||
for (const auto & [key, value_ab] : count_ab)
|
||||
{
|
||||
Float64 value_a = count_a.at(key.items[0]);
|
||||
Float64 value_b = count_b.at(key.items[1]);
|
||||
|
||||
phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count;
|
||||
}
|
||||
|
||||
phi /= count;
|
||||
|
||||
Float64 res = std::max(0.0, phi - (static_cast<Float64>(count_a.size()) - 1) * (static_cast<Float64>(count_b.size()) - 1) / (count - 1));
|
||||
Float64 correction_a = count_a.size() - (static_cast<Float64>(count_a.size()) - 1) * (static_cast<Float64>(count_a.size()) - 1) / (count - 1);
|
||||
Float64 correction_b = count_b.size() - (static_cast<Float64>(count_b.size()) - 1) * (static_cast<Float64>(count_b.size()) - 1) / (count - 1);
|
||||
res /= std::min(correction_a, correction_b) - 1;
|
||||
|
||||
return sqrt(res);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionCramersVBiasCorrected(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction(CramersVBiasCorrectedData::getName(),
|
||||
[](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
return std::make_shared<AggregateFunctionCrossTab<CramersVBiasCorrectedData>>(argument_types);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
@ -1,67 +0,0 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionCramersV.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include "registerAggregateFunctions.h"
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace
|
||||
{
|
||||
|
||||
|
||||
struct BiasCorrectionData : public AggregateFunctionCramersVData
|
||||
{
|
||||
Float64 get_result() const
|
||||
{
|
||||
if (cur_size < 2){
|
||||
throw Exception("Aggregate function cramer's v bias corrected at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
Float64 phi = 0.0;
|
||||
for (const auto & cell : pairs) {
|
||||
UInt128 hash_pair = cell.getKey();
|
||||
UInt64 count_of_pair_tmp = cell.getMapped();
|
||||
Float64 count_of_pair = Float64(count_of_pair_tmp);
|
||||
UInt64 hash1 = (hash_pair << 64 >> 64);
|
||||
UInt64 hash2 = (hash_pair >> 64);
|
||||
|
||||
UInt64 count1_tmp = n_i.find(hash1)->getMapped();
|
||||
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
|
||||
Float64 count1 = static_cast<Float64>(count1_tmp);
|
||||
Float64 count2 = Float64(count2_tmp);
|
||||
|
||||
phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size)
|
||||
- 2 * count_of_pair + (count1 * count2 / cur_size));
|
||||
}
|
||||
phi /= cur_size;
|
||||
Float64 answ = std::max(0.0, phi - ((static_cast<Float64>(n_i.size()) - 1) * (static_cast<Float64>(n_j.size()) - 1) / (cur_size - 1)));
|
||||
Float64 k = n_i.size() - (static_cast<Float64>(n_i.size()) - 1) * (static_cast<Float64>(n_i.size()) - 1) / (cur_size - 1);
|
||||
Float64 r = n_j.size() - (static_cast<Float64>(n_j.size()) - 1) * (static_cast<Float64>(n_j.size()) - 1) / (cur_size - 1);
|
||||
Float64 q = std::min(k, r);
|
||||
answ /= (q - 1);
|
||||
return sqrt(answ);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionCramersVBiasCorrection(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
return std::make_shared<AggregateFunctionCramersV<BiasCorrectionData>>(argument_types);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionCramersVBiasCorrection(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("CramersVBiasCorrection", createAggregateFunctionCramersVBiasCorrection);
|
||||
}
|
||||
|
||||
}
|
@ -1,48 +1,43 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionCramersV.h>
|
||||
#include <AggregateFunctions/CrossTab.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include "registerAggregateFunctions.h"
|
||||
#include <memory>
|
||||
#include <cmath>
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
|
||||
struct TheilsUData : public AggregateFunctionCramersVData
|
||||
struct TheilsUData : CrossTabData
|
||||
{
|
||||
Float64 get_result() const
|
||||
static const char * getName()
|
||||
{
|
||||
if (cur_size < 2){
|
||||
throw Exception("Aggregate function theil's u requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
Float64 h_x = 0.0;
|
||||
for (const auto & cell : n_i) {
|
||||
UInt64 count_x_tmp = cell.getMapped();
|
||||
Float64 count_x = Float64(count_x_tmp);
|
||||
h_x += (count_x / cur_size) * (log(count_x / cur_size));
|
||||
return "theilsU";
|
||||
}
|
||||
|
||||
Float64 getResult() const
|
||||
{
|
||||
if (count < 2)
|
||||
return std::numeric_limits<Float64>::quiet_NaN();
|
||||
|
||||
Float64 h_a = 0.0;
|
||||
for (const auto & [key, value] : count_a)
|
||||
{
|
||||
Float64 value_float = value;
|
||||
h_a += (value_float / count) * log(value_float / count);
|
||||
}
|
||||
|
||||
Float64 dep = 0.0;
|
||||
for (const auto & cell : pairs) {
|
||||
UInt128 hash_pair = cell.getKey();
|
||||
UInt64 count_of_pair_tmp = cell.getMapped();
|
||||
Float64 count_of_pair = Float64(count_of_pair_tmp);
|
||||
for (const auto & [key, value] : count_ab)
|
||||
{
|
||||
Float64 value_ab = value;
|
||||
Float64 value_b = count_b.at(key.items[1]);
|
||||
|
||||
UInt64 hash2 = (hash_pair >> 64);
|
||||
|
||||
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
|
||||
Float64 count2 = Float64 (count2_tmp);
|
||||
|
||||
dep += (count_of_pair / cur_size) * log(count_of_pair / count2);
|
||||
dep += (value_ab / count) * log(value_ab / value_b);
|
||||
}
|
||||
|
||||
dep -= h_x;
|
||||
@ -51,18 +46,16 @@ namespace
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionTheilsU(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
return std::make_shared<AggregateFunctionCramersV<TheilsUData>>(argument_types);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionTheilsU(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("TheilsU", createAggregateFunctionTheilsU);
|
||||
factory.registerFunction(TheilsUData::getName(),
|
||||
[](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||
{
|
||||
assertNoParameters(name, parameters);
|
||||
return std::make_shared<AggregateFunctionCrossTab<TheilsUData>>(argument_types);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
134
src/AggregateFunctions/CrossTab.h
Normal file
134
src/AggregateFunctions/CrossTab.h
Normal file
@ -0,0 +1,134 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Common/HashTable/HashMap.h>
|
||||
#include <AggregateFunctions/UniqVariadicHash.h>
|
||||
|
||||
|
||||
/** Aggregate function that calculates statistics on top of cross-tab:
|
||||
* - histogram of every argument and every pair of elements.
|
||||
* These statistics include:
|
||||
* - Cramer's V;
|
||||
* - Theil's U;
|
||||
* - contingency coefficient;
|
||||
* It can be interpreted as interdependency coefficient between arguments;
|
||||
* or non-parametric correlation coefficient.
|
||||
*/
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct CrossTabData
|
||||
{
|
||||
/// Total count.
|
||||
UInt64 count = 0;
|
||||
|
||||
/// Count of every value of the first and second argument (values are pre-hashed).
|
||||
/// Note: non-cryptographic 64bit hash is used, it means that the calculation is approximate.
|
||||
HashMapWithStackMemory<UInt64, UInt64, TrivialHash, 4> count_a;
|
||||
HashMapWithStackMemory<UInt64, UInt64, TrivialHash, 4> count_b;
|
||||
|
||||
/// Count of every pair of values. We pack two hashes into UInt128.
|
||||
HashMapWithStackMemory<UInt128, UInt64, UInt128Hash, 4> count_ab;
|
||||
|
||||
|
||||
void add(UInt64 hash1, UInt64 hash2)
|
||||
{
|
||||
++count;
|
||||
++count_a[hash1];
|
||||
++count_b[hash2];
|
||||
|
||||
UInt128 hash_pair{hash1, hash2};
|
||||
++count_ab[hash_pair];
|
||||
}
|
||||
|
||||
void merge(const CrossTabData & other)
|
||||
{
|
||||
count += other.count;
|
||||
for (const auto & [key, value] : other.count_a)
|
||||
count_a[key] += value;
|
||||
for (const auto & [key, value] : other.count_b)
|
||||
count_b[key] += value;
|
||||
for (const auto & [key, value] : other.count_ab)
|
||||
count_ab[key] += value;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer &buf) const
|
||||
{
|
||||
writeBinary(count, buf);
|
||||
count_a.write(buf);
|
||||
count_b.write(buf);
|
||||
count_ab.write(buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer &buf)
|
||||
{
|
||||
readBinary(count, buf);
|
||||
count_a.read(buf);
|
||||
count_b.read(buf);
|
||||
count_ab.read(buf);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename Data>
|
||||
class AggregateFunctionCrossTab : public IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionCrossTab(const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>({arguments}, {})
|
||||
{
|
||||
}
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return Data::getName();
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
||||
void add(
|
||||
AggregateDataPtr __restrict place,
|
||||
const IColumn ** columns,
|
||||
size_t row_num,
|
||||
Arena *) const override
|
||||
{
|
||||
UInt64 hash1 = UniqVariadicHash<false, false>::apply(1, &columns[0], row_num);
|
||||
UInt64 hash2 = UniqVariadicHash<false, false>::apply(1, &columns[1], row_num);
|
||||
|
||||
this->data(place).add(hash1, hash2);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||
{
|
||||
Float64 result = this->data(place).getResult();
|
||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
||||
column.getData().push_back(result);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -50,8 +50,8 @@ void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionCramersV(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionTheilsU(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionContingencyCoefficient(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionCramersVBiasCorrection(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionContingency(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionCramersVBiasCorrected(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSingleValueOrNull(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSequenceNextNode(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &);
|
||||
@ -105,8 +105,8 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionsBitwise(factory);
|
||||
registerAggregateFunctionCramersV(factory);
|
||||
registerAggregateFunctionTheilsU(factory);
|
||||
registerAggregateFunctionContingencyCoefficient(factory);
|
||||
registerAggregateFunctionCramersVBiasCorrection(factory);
|
||||
registerAggregateFunctionContingency(factory);
|
||||
registerAggregateFunctionCramersVBiasCorrected(factory);
|
||||
registerAggregateFunctionsBitmap(factory);
|
||||
registerAggregateFunctionsMaxIntersections(factory);
|
||||
registerAggregateFunctionHistogram(factory);
|
||||
|
@ -262,6 +262,13 @@ public:
|
||||
|
||||
return it->getMapped();
|
||||
}
|
||||
|
||||
typename Cell::Mapped & ALWAYS_INLINE at(const Key & x) const
|
||||
{
|
||||
if (auto it = this->find(x); it != this->end())
|
||||
return it->getMapped();
|
||||
throw DB::Exception("Cannot find element in HashMap::at method", DB::ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
};
|
||||
|
||||
namespace std
|
||||
|
Loading…
Reference in New Issue
Block a user