mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-29 05:00:47 +00:00
Merging contingency coefficients
This commit is contained in:
parent
9dc66e1e72
commit
4a094c2efd
@ -1,61 +1,53 @@
|
|||||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
#include <AggregateFunctions/AggregateFunctionCramersV.h>
|
#include <AggregateFunctions/CrossTab.h>
|
||||||
#include <AggregateFunctions/FactoryHelpers.h>
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
#include <AggregateFunctions/Helpers.h>
|
#include <AggregateFunctions/Helpers.h>
|
||||||
#include "registerAggregateFunctions.h"
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
namespace ErrorCodes
|
|
||||||
{
|
|
||||||
extern const int BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
|
struct ContingencyData : CrossTabData
|
||||||
struct ContingencyData : public AggregateFunctionCramersVData
|
|
||||||
{
|
{
|
||||||
Float64 get_result() const
|
static const char * getName()
|
||||||
{
|
{
|
||||||
if (cur_size < 2){
|
return "contingency";
|
||||||
throw Exception("Aggregate function contingency coefficient requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Float64 getResult() const
|
||||||
|
{
|
||||||
|
if (count < 2)
|
||||||
|
return std::numeric_limits<Float64>::quiet_NaN();
|
||||||
|
|
||||||
Float64 phi = 0.0;
|
Float64 phi = 0.0;
|
||||||
for (const auto & cell : pairs) {
|
for (const auto & [key, value_ab] : count_ab)
|
||||||
UInt128 hash_pair = cell.getKey();
|
{
|
||||||
UInt64 count_of_pair_tmp = cell.getMapped();
|
Float64 value_a = count_a.at(key.items[0]);
|
||||||
Float64 count_of_pair = Float64(count_of_pair_tmp);
|
Float64 value_b = count_b.at(key.items[1]);
|
||||||
UInt64 hash1 = (hash_pair << 64 >> 64);
|
|
||||||
UInt64 hash2 = (hash_pair >> 64);
|
|
||||||
|
|
||||||
UInt64 count1_tmp = n_i.find(hash1)->getMapped();
|
phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count;
|
||||||
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
|
|
||||||
Float64 count1 = static_cast<Float64>(count1_tmp);
|
|
||||||
Float64 count2 = Float64(count2_tmp);
|
|
||||||
|
|
||||||
phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size)
|
|
||||||
- 2 * count_of_pair + (count1 * count2 / cur_size));
|
|
||||||
}
|
}
|
||||||
phi /= cur_size;
|
phi /= count;
|
||||||
return sqrt(phi / (phi + cur_size));
|
|
||||||
|
return sqrt(phi / (phi + count));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
}   /// End of the anonymous namespace: the registration function below needs external linkage.

/** Registers the `contingency` aggregate function in the factory.
  *
  * The function is declared in registerAggregateFunctions.cpp inside namespace DB,
  * so it must be defined in namespace DB itself and not in the anonymous namespace.
  */
void registerAggregateFunctionContingency(AggregateFunctionFactory & factory)
{
    factory.registerFunction(ContingencyData::getName(),
        [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
        {
            /// AggregateFunctionCrossTab::add reads columns[0] and columns[1]
            /// unconditionally, so anything but two arguments must be rejected here.
            assertBinary(name, argument_types);
            assertNoParameters(name, parameters);
            return std::make_shared<AggregateFunctionCrossTab<ContingencyData>>(argument_types);
        });
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,27 +1,56 @@
|
|||||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
#include <AggregateFunctions/AggregateFunctionCramersV.h>
|
#include <AggregateFunctions/CrossTab.h>
|
||||||
#include <AggregateFunctions/FactoryHelpers.h>
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
#include <AggregateFunctions/Helpers.h>
|
#include <AggregateFunctions/Helpers.h>
|
||||||
#include "registerAggregateFunctions.h"
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
AggregateFunctionPtr createAggregateFunctionCramersV(const std::string & name, const DataTypes & argument_types,
|
struct CramersVData : CrossTabData
|
||||||
const Array & parameters, const Settings *)
|
|
||||||
{
|
{
|
||||||
assertNoParameters(name, parameters);
|
static const char * getName()
|
||||||
return std::make_shared<AggregateFunctionCramersV<AggregateFunctionCramersVData>>(argument_types);
|
{
|
||||||
|
return "cramersV";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Float64 getResult() const
|
||||||
|
{
|
||||||
|
if (count < 2)
|
||||||
|
return std::numeric_limits<Float64>::quiet_NaN();
|
||||||
|
|
||||||
|
Float64 phi = 0.0;
|
||||||
|
for (const auto & [key, value_ab] : count_ab)
|
||||||
|
{
|
||||||
|
Float64 value_a = count_a.at(key.items[0]);
|
||||||
|
Float64 value_b = count_b.at(key.items[1]);
|
||||||
|
|
||||||
|
phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count;
|
||||||
|
}
|
||||||
|
|
||||||
|
phi /= count;
|
||||||
|
UInt64 q = std::min(count_a.size(), count_b.size());
|
||||||
|
phi /= q - 1;
|
||||||
|
|
||||||
|
return sqrt(phi);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Registers the `cramersV` aggregate function in the factory.
void registerAggregateFunctionCramersV(AggregateFunctionFactory & factory)
{
    factory.registerFunction(CramersVData::getName(),
        [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
        {
            /// AggregateFunctionCrossTab::add reads columns[0] and columns[1]
            /// unconditionally, so anything but two arguments must be rejected here.
            assertBinary(name, argument_types);
            assertNoParameters(name, parameters);
            return std::make_shared<AggregateFunctionCrossTab<CramersVData>>(argument_types);
        });
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -1,207 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <AggregateFunctions/IAggregateFunction.h>
|
|
||||||
#include <Columns/ColumnVector.h>
|
|
||||||
#include <Columns/ColumnTuple.h>
|
|
||||||
#include <Common/assert_cast.h>
|
|
||||||
#include <Common/FieldVisitors.h>
|
|
||||||
#include <Core/Types.h>
|
|
||||||
#include <DataTypes/DataTypesDecimal.h>
|
|
||||||
#include <DataTypes/DataTypeNullable.h>
|
|
||||||
#include <DataTypes/DataTypesNumber.h>
|
|
||||||
#include <DataTypes/DataTypeTuple.h>
|
|
||||||
#include <IO/ReadHelpers.h>
|
|
||||||
#include <IO/WriteHelpers.h>
|
|
||||||
#include <Common/HashTable/HashMap.h>
|
|
||||||
#include <AggregateFunctions/UniqVariadicHash.h>
|
|
||||||
#include <limits>
|
|
||||||
#include <cmath>
|
|
||||||
|
|
||||||
#include <type_traits>
|
|
||||||
|
|
||||||
|
|
||||||
namespace ErrorCodes
|
|
||||||
{
|
|
||||||
extern const int BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace DB
|
|
||||||
{
|
|
||||||
|
|
||||||
|
|
||||||
struct AggregateFunctionCramersVData
|
|
||||||
{
|
|
||||||
size_t cur_size = 0;
|
|
||||||
HashMap<UInt64, UInt64> n_i;
|
|
||||||
HashMap<UInt64, UInt64> n_j;
|
|
||||||
HashMap<UInt128, UInt64> pairs;
|
|
||||||
|
|
||||||
|
|
||||||
void add(UInt64 hash1, UInt64 hash2)
|
|
||||||
{
|
|
||||||
cur_size += 1;
|
|
||||||
n_i[hash1] += 1;
|
|
||||||
n_j[hash2] += 1;
|
|
||||||
|
|
||||||
UInt128 hash_pair = hash1 | (static_cast<UInt128>(hash2) << 64);
|
|
||||||
pairs[hash_pair] += 1;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
void merge(const AggregateFunctionCramersVData &other)
|
|
||||||
{
|
|
||||||
cur_size += other.cur_size;
|
|
||||||
|
|
||||||
for (const auto& pair : other.n_i) {
|
|
||||||
UInt64 hash1 = pair.getKey();
|
|
||||||
UInt64 count = pair.getMapped();
|
|
||||||
n_i[hash1] += count;
|
|
||||||
}
|
|
||||||
for (const auto& pair : other.n_j) {
|
|
||||||
UInt64 hash1 = pair.getKey();
|
|
||||||
UInt64 count = pair.getMapped();
|
|
||||||
n_j[hash1] += count;
|
|
||||||
}
|
|
||||||
for (const auto& pair : other.pairs) {
|
|
||||||
UInt128 hash1 = pair.getKey();
|
|
||||||
UInt64 count = pair.getMapped();
|
|
||||||
pairs[hash1] += count;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void serialize(WriteBuffer &buf) const
|
|
||||||
{
|
|
||||||
writeBinary(cur_size, buf);
|
|
||||||
n_i.write(buf);
|
|
||||||
n_j.write(buf);
|
|
||||||
pairs.write(buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
void deserialize(ReadBuffer &buf)
|
|
||||||
{
|
|
||||||
readBinary(cur_size, buf);
|
|
||||||
n_i.read(buf);
|
|
||||||
n_j.read(buf);
|
|
||||||
pairs.read(buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
Float64 get_result() const
|
|
||||||
{
|
|
||||||
if (cur_size < 2){
|
|
||||||
throw Exception("Aggregate function cramer's v requires et least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
|
|
||||||
}
|
|
||||||
Float64 phi = 0.0;
|
|
||||||
for (const auto & cell : pairs) {
|
|
||||||
UInt128 hash_pair = cell.getKey();
|
|
||||||
UInt64 count_of_pair_tmp = cell.getMapped();
|
|
||||||
Float64 count_of_pair = Float64(count_of_pair_tmp);
|
|
||||||
UInt64 hash1 = (hash_pair << 64 >> 64);
|
|
||||||
UInt64 hash2 = (hash_pair >> 64);
|
|
||||||
|
|
||||||
UInt64 count1_tmp = n_i.find(hash1)->getMapped();
|
|
||||||
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
|
|
||||||
Float64 count1 = static_cast<Float64>(count1_tmp);
|
|
||||||
Float64 count2 = Float64(count2_tmp);
|
|
||||||
|
|
||||||
phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size)
|
|
||||||
- 2 * count_of_pair + (count1 * count2 / cur_size));
|
|
||||||
}
|
|
||||||
phi /= cur_size;
|
|
||||||
|
|
||||||
UInt64 q = std::min(n_i.size(), n_j.size());
|
|
||||||
phi /= (q - 1);
|
|
||||||
return sqrt(phi);
|
|
||||||
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Data>
|
|
||||||
|
|
||||||
class AggregateFunctionCramersV : public
|
|
||||||
IAggregateFunctionDataHelper<
|
|
||||||
Data,
|
|
||||||
AggregateFunctionCramersV<Data>
|
|
||||||
>
|
|
||||||
{
|
|
||||||
|
|
||||||
public:
|
|
||||||
AggregateFunctionCramersV(
|
|
||||||
const DataTypes & arguments
|
|
||||||
):
|
|
||||||
IAggregateFunctionDataHelper<
|
|
||||||
Data,
|
|
||||||
AggregateFunctionCramersV<Data>
|
|
||||||
> ({arguments}, {})
|
|
||||||
{
|
|
||||||
// notice: arguments has been in factory
|
|
||||||
}
|
|
||||||
|
|
||||||
String getName() const override
|
|
||||||
{
|
|
||||||
return "CramersV";
|
|
||||||
}
|
|
||||||
|
|
||||||
bool allocatesMemoryInArena() const override { return false; }
|
|
||||||
|
|
||||||
DataTypePtr getReturnType() const override
|
|
||||||
{
|
|
||||||
return std::make_shared<DataTypeNumber<Float64>>();
|
|
||||||
}
|
|
||||||
|
|
||||||
void add(
|
|
||||||
AggregateDataPtr __restrict place,
|
|
||||||
const IColumn ** columns,
|
|
||||||
size_t row_num,
|
|
||||||
Arena *
|
|
||||||
) const override
|
|
||||||
{
|
|
||||||
UInt64 hash1 = UniqVariadicHash<false, false>::apply(1, columns, row_num);
|
|
||||||
UInt64 hash2 = UniqVariadicHash<false, false>::apply(1, columns + 1, row_num);
|
|
||||||
|
|
||||||
this->data(place).add(hash1, hash2);
|
|
||||||
}
|
|
||||||
|
|
||||||
void merge(
|
|
||||||
AggregateDataPtr __restrict place,
|
|
||||||
ConstAggregateDataPtr rhs, Arena *
|
|
||||||
) const override
|
|
||||||
{
|
|
||||||
this->data(place).merge(this->data(rhs));
|
|
||||||
}
|
|
||||||
|
|
||||||
void serialize(
|
|
||||||
ConstAggregateDataPtr __restrict place,
|
|
||||||
WriteBuffer & buf
|
|
||||||
) const override
|
|
||||||
{
|
|
||||||
this->data(place).serialize(buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
void deserialize(
|
|
||||||
AggregateDataPtr __restrict place,
|
|
||||||
ReadBuffer & buf, Arena *
|
|
||||||
) const override
|
|
||||||
{
|
|
||||||
this->data(place).deserialize(buf);
|
|
||||||
}
|
|
||||||
|
|
||||||
void insertResultInto(
|
|
||||||
AggregateDataPtr __restrict place,
|
|
||||||
IColumn & to,
|
|
||||||
Arena *
|
|
||||||
) const override
|
|
||||||
{
|
|
||||||
Float64 result = this->data(place).get_result();
|
|
||||||
// std::cerr << "cur_size" << this->data(place).cur_size << '\n';
|
|
||||||
// std::cerr << "n_i size" << this->data(place).n_i.size() << '\n';
|
|
||||||
// std::cerr << "n_j size" << this->data(place).n_j.size() << '\n';
|
|
||||||
// std::cerr << "pair size " << this->data(place).pairs.size() << '\n';
|
|
||||||
// std::cerr << "result " << result << '\n';
|
|
||||||
|
|
||||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
|
||||||
column.getData().push_back(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
@ -0,0 +1,59 @@
|
|||||||
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
|
#include <AggregateFunctions/CrossTab.h>
|
||||||
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
|
#include <AggregateFunctions/Helpers.h>
|
||||||
|
#include <memory>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
struct CramersVBiasCorrectedData : CrossTabData
|
||||||
|
{
|
||||||
|
static const char * getName()
|
||||||
|
{
|
||||||
|
return "cramersVBiasCorrected";
|
||||||
|
}
|
||||||
|
|
||||||
|
Float64 getResult() const
|
||||||
|
{
|
||||||
|
if (count < 2)
|
||||||
|
return std::numeric_limits<Float64>::quiet_NaN();
|
||||||
|
|
||||||
|
Float64 phi = 0.0;
|
||||||
|
for (const auto & [key, value_ab] : count_ab)
|
||||||
|
{
|
||||||
|
Float64 value_a = count_a.at(key.items[0]);
|
||||||
|
Float64 value_b = count_b.at(key.items[1]);
|
||||||
|
|
||||||
|
phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count;
|
||||||
|
}
|
||||||
|
|
||||||
|
phi /= count;
|
||||||
|
|
||||||
|
Float64 res = std::max(0.0, phi - (static_cast<Float64>(count_a.size()) - 1) * (static_cast<Float64>(count_b.size()) - 1) / (count - 1));
|
||||||
|
Float64 correction_a = count_a.size() - (static_cast<Float64>(count_a.size()) - 1) * (static_cast<Float64>(count_a.size()) - 1) / (count - 1);
|
||||||
|
Float64 correction_b = count_b.size() - (static_cast<Float64>(count_b.size()) - 1) * (static_cast<Float64>(count_b.size()) - 1) / (count - 1);
|
||||||
|
res /= std::min(correction_a, correction_b) - 1;
|
||||||
|
|
||||||
|
return sqrt(res);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Registers the `cramersVBiasCorrected` aggregate function in the factory.
void registerAggregateFunctionCramersVBiasCorrected(AggregateFunctionFactory & factory)
{
    factory.registerFunction(CramersVBiasCorrectedData::getName(),
        [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
        {
            /// AggregateFunctionCrossTab::add reads columns[0] and columns[1]
            /// unconditionally, so anything but two arguments must be rejected here.
            assertBinary(name, argument_types);
            assertNoParameters(name, parameters);
            return std::make_shared<AggregateFunctionCrossTab<CramersVBiasCorrectedData>>(argument_types);
        });
}
|
||||||
|
|
||||||
|
}
|
@ -1,67 +0,0 @@
|
|||||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
|
||||||
#include <AggregateFunctions/AggregateFunctionCramersV.h>
|
|
||||||
#include <AggregateFunctions/FactoryHelpers.h>
|
|
||||||
#include <AggregateFunctions/Helpers.h>
|
|
||||||
#include "registerAggregateFunctions.h"
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
|
|
||||||
namespace ErrorCodes
|
|
||||||
{
|
|
||||||
extern const int BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace DB
|
|
||||||
{
|
|
||||||
namespace
|
|
||||||
{
|
|
||||||
|
|
||||||
|
|
||||||
struct BiasCorrectionData : public AggregateFunctionCramersVData
|
|
||||||
{
|
|
||||||
Float64 get_result() const
|
|
||||||
{
|
|
||||||
if (cur_size < 2){
|
|
||||||
throw Exception("Aggregate function cramer's v bias corrected at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
|
|
||||||
}
|
|
||||||
Float64 phi = 0.0;
|
|
||||||
for (const auto & cell : pairs) {
|
|
||||||
UInt128 hash_pair = cell.getKey();
|
|
||||||
UInt64 count_of_pair_tmp = cell.getMapped();
|
|
||||||
Float64 count_of_pair = Float64(count_of_pair_tmp);
|
|
||||||
UInt64 hash1 = (hash_pair << 64 >> 64);
|
|
||||||
UInt64 hash2 = (hash_pair >> 64);
|
|
||||||
|
|
||||||
UInt64 count1_tmp = n_i.find(hash1)->getMapped();
|
|
||||||
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
|
|
||||||
Float64 count1 = static_cast<Float64>(count1_tmp);
|
|
||||||
Float64 count2 = Float64(count2_tmp);
|
|
||||||
|
|
||||||
phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size)
|
|
||||||
- 2 * count_of_pair + (count1 * count2 / cur_size));
|
|
||||||
}
|
|
||||||
phi /= cur_size;
|
|
||||||
Float64 answ = std::max(0.0, phi - ((static_cast<Float64>(n_i.size()) - 1) * (static_cast<Float64>(n_j.size()) - 1) / (cur_size - 1)));
|
|
||||||
Float64 k = n_i.size() - (static_cast<Float64>(n_i.size()) - 1) * (static_cast<Float64>(n_i.size()) - 1) / (cur_size - 1);
|
|
||||||
Float64 r = n_j.size() - (static_cast<Float64>(n_j.size()) - 1) * (static_cast<Float64>(n_j.size()) - 1) / (cur_size - 1);
|
|
||||||
Float64 q = std::min(k, r);
|
|
||||||
answ /= (q - 1);
|
|
||||||
return sqrt(answ);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
/// Factory callback: validates the call and creates the bias-corrected Cramer's V aggregate function.
AggregateFunctionPtr createAggregateFunctionCramersVBiasCorrection(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    /// AggregateFunctionCramersV::add reads two columns unconditionally,
    /// so anything but two arguments must be rejected here.
    assertBinary(name, argument_types);
    assertNoParameters(name, parameters);
    return std::make_shared<AggregateFunctionCramersV<BiasCorrectionData>>(argument_types);
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Makes the bias-corrected Cramer's V available under the name "CramersVBiasCorrection".
void registerAggregateFunctionCramersVBiasCorrection(AggregateFunctionFactory & factory)
{
    factory.registerFunction(
        "CramersVBiasCorrection",
        createAggregateFunctionCramersVBiasCorrection);
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,48 +1,43 @@
|
|||||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
#include <AggregateFunctions/AggregateFunctionCramersV.h>
|
#include <AggregateFunctions/CrossTab.h>
|
||||||
#include <AggregateFunctions/FactoryHelpers.h>
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
#include <AggregateFunctions/Helpers.h>
|
#include <AggregateFunctions/Helpers.h>
|
||||||
#include "registerAggregateFunctions.h"
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
namespace ErrorCodes
|
|
||||||
{
|
|
||||||
extern const int BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
|
struct TheilsUData : CrossTabData
|
||||||
struct TheilsUData : public AggregateFunctionCramersVData
|
|
||||||
{
|
{
|
||||||
Float64 get_result() const
|
static const char * getName()
|
||||||
{
|
{
|
||||||
if (cur_size < 2){
|
return "theilsU";
|
||||||
throw Exception("Aggregate function theil's u requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS);
|
|
||||||
}
|
|
||||||
Float64 h_x = 0.0;
|
|
||||||
for (const auto & cell : n_i) {
|
|
||||||
UInt64 count_x_tmp = cell.getMapped();
|
|
||||||
Float64 count_x = Float64(count_x_tmp);
|
|
||||||
h_x += (count_x / cur_size) * (log(count_x / cur_size));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Float64 getResult() const
|
||||||
|
{
|
||||||
|
if (count < 2)
|
||||||
|
return std::numeric_limits<Float64>::quiet_NaN();
|
||||||
|
|
||||||
|
Float64 h_a = 0.0;
|
||||||
|
for (const auto & [key, value] : count_a)
|
||||||
|
{
|
||||||
|
Float64 value_float = value;
|
||||||
|
h_a += (value_float / count) * log(value_float / count);
|
||||||
|
}
|
||||||
|
|
||||||
Float64 dep = 0.0;
|
Float64 dep = 0.0;
|
||||||
for (const auto & cell : pairs) {
|
for (const auto & [key, value] : count_ab)
|
||||||
UInt128 hash_pair = cell.getKey();
|
{
|
||||||
UInt64 count_of_pair_tmp = cell.getMapped();
|
Float64 value_ab = value;
|
||||||
Float64 count_of_pair = Float64(count_of_pair_tmp);
|
Float64 value_b = count_b.at(key.items[1]);
|
||||||
|
|
||||||
UInt64 hash2 = (hash_pair >> 64);
|
dep += (value_ab / count) * log(value_ab / value_b);
|
||||||
|
|
||||||
UInt64 count2_tmp = n_j.find(hash2)->getMapped();
|
|
||||||
Float64 count2 = Float64 (count2_tmp);
|
|
||||||
|
|
||||||
dep += (count_of_pair / cur_size) * log(count_of_pair / count2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dep -= h_x;
|
dep -= h_x;
|
||||||
@ -51,18 +46,16 @@ namespace
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
AggregateFunctionPtr createAggregateFunctionTheilsU(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
|
||||||
{
|
|
||||||
assertNoParameters(name, parameters);
|
|
||||||
return std::make_shared<AggregateFunctionCramersV<TheilsUData>>(argument_types);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
void registerAggregateFunctionTheilsU(AggregateFunctionFactory & factory)
|
void registerAggregateFunctionTheilsU(AggregateFunctionFactory & factory)
|
||||||
{
|
{
|
||||||
factory.registerFunction("TheilsU", createAggregateFunctionTheilsU);
|
factory.registerFunction(TheilsUData::getName(),
|
||||||
|
[](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
|
||||||
|
{
|
||||||
|
assertNoParameters(name, parameters);
|
||||||
|
return std::make_shared<AggregateFunctionCrossTab<TheilsUData>>(argument_types);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
134
src/AggregateFunctions/CrossTab.h
Normal file
134
src/AggregateFunctions/CrossTab.h
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <AggregateFunctions/IAggregateFunction.h>
|
||||||
|
#include <Common/assert_cast.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <Common/HashTable/HashMap.h>
|
||||||
|
#include <AggregateFunctions/UniqVariadicHash.h>
|
||||||
|
|
||||||
|
|
||||||
|
/** Aggregate function that calculates statistics on top of cross-tab:
|
||||||
|
* - histogram of every argument and every pair of elements.
|
||||||
|
* These statistics include:
|
||||||
|
* - Cramer's V;
|
||||||
|
* - Theil's U;
|
||||||
|
* - contingency coefficient;
|
||||||
|
* It can be interpreted as interdependency coefficient between arguments;
|
||||||
|
* or non-parametric correlation coefficient.
|
||||||
|
*/
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
struct CrossTabData
|
||||||
|
{
|
||||||
|
/// Total count.
|
||||||
|
UInt64 count = 0;
|
||||||
|
|
||||||
|
/// Count of every value of the first and second argument (values are pre-hashed).
|
||||||
|
/// Note: non-cryptographic 64bit hash is used, it means that the calculation is approximate.
|
||||||
|
HashMapWithStackMemory<UInt64, UInt64, TrivialHash, 4> count_a;
|
||||||
|
HashMapWithStackMemory<UInt64, UInt64, TrivialHash, 4> count_b;
|
||||||
|
|
||||||
|
/// Count of every pair of values. We pack two hashes into UInt128.
|
||||||
|
HashMapWithStackMemory<UInt128, UInt64, UInt128Hash, 4> count_ab;
|
||||||
|
|
||||||
|
|
||||||
|
void add(UInt64 hash1, UInt64 hash2)
|
||||||
|
{
|
||||||
|
++count;
|
||||||
|
++count_a[hash1];
|
||||||
|
++count_b[hash2];
|
||||||
|
|
||||||
|
UInt128 hash_pair{hash1, hash2};
|
||||||
|
++count_ab[hash_pair];
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(const CrossTabData & other)
|
||||||
|
{
|
||||||
|
count += other.count;
|
||||||
|
for (const auto & [key, value] : other.count_a)
|
||||||
|
count_a[key] += value;
|
||||||
|
for (const auto & [key, value] : other.count_b)
|
||||||
|
count_b[key] += value;
|
||||||
|
for (const auto & [key, value] : other.count_ab)
|
||||||
|
count_ab[key] += value;
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(WriteBuffer &buf) const
|
||||||
|
{
|
||||||
|
writeBinary(count, buf);
|
||||||
|
count_a.write(buf);
|
||||||
|
count_b.write(buf);
|
||||||
|
count_ab.write(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deserialize(ReadBuffer &buf)
|
||||||
|
{
|
||||||
|
readBinary(count, buf);
|
||||||
|
count_a.read(buf);
|
||||||
|
count_b.read(buf);
|
||||||
|
count_ab.read(buf);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <typename Data>
|
||||||
|
class AggregateFunctionCrossTab : public IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
AggregateFunctionCrossTab(const DataTypes & arguments)
|
||||||
|
: IAggregateFunctionDataHelper<Data, AggregateFunctionCrossTab<Data>>({arguments}, {})
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
String getName() const override
|
||||||
|
{
|
||||||
|
return Data::getName();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool allocatesMemoryInArena() const override
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
DataTypePtr getReturnType() const override
|
||||||
|
{
|
||||||
|
return std::make_shared<DataTypeNumber<Float64>>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(
|
||||||
|
AggregateDataPtr __restrict place,
|
||||||
|
const IColumn ** columns,
|
||||||
|
size_t row_num,
|
||||||
|
Arena *) const override
|
||||||
|
{
|
||||||
|
UInt64 hash1 = UniqVariadicHash<false, false>::apply(1, &columns[0], row_num);
|
||||||
|
UInt64 hash2 = UniqVariadicHash<false, false>::apply(1, &columns[1], row_num);
|
||||||
|
|
||||||
|
this->data(place).add(hash1, hash2);
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).merge(this->data(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf) const override
|
||||||
|
{
|
||||||
|
this->data(place).serialize(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, Arena *) const override
|
||||||
|
{
|
||||||
|
this->data(place).deserialize(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
|
||||||
|
{
|
||||||
|
Float64 result = this->data(place).getResult();
|
||||||
|
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
||||||
|
column.getData().push_back(result);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -50,8 +50,8 @@ void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &);
|
|||||||
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &);
|
void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionCramersV(AggregateFunctionFactory &);
|
void registerAggregateFunctionCramersV(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionTheilsU(AggregateFunctionFactory &);
|
void registerAggregateFunctionTheilsU(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionContingencyCoefficient(AggregateFunctionFactory &);
|
void registerAggregateFunctionContingency(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionCramersVBiasCorrection(AggregateFunctionFactory &);
|
void registerAggregateFunctionCramersVBiasCorrected(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionSingleValueOrNull(AggregateFunctionFactory &);
|
void registerAggregateFunctionSingleValueOrNull(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionSequenceNextNode(AggregateFunctionFactory &);
|
void registerAggregateFunctionSequenceNextNode(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &);
|
void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &);
|
||||||
@ -105,8 +105,8 @@ void registerAggregateFunctions()
|
|||||||
registerAggregateFunctionsBitwise(factory);
|
registerAggregateFunctionsBitwise(factory);
|
||||||
registerAggregateFunctionCramersV(factory);
|
registerAggregateFunctionCramersV(factory);
|
||||||
registerAggregateFunctionTheilsU(factory);
|
registerAggregateFunctionTheilsU(factory);
|
||||||
registerAggregateFunctionContingencyCoefficient(factory);
|
registerAggregateFunctionContingency(factory);
|
||||||
registerAggregateFunctionCramersVBiasCorrection(factory);
|
registerAggregateFunctionCramersVBiasCorrected(factory);
|
||||||
registerAggregateFunctionsBitmap(factory);
|
registerAggregateFunctionsBitmap(factory);
|
||||||
registerAggregateFunctionsMaxIntersections(factory);
|
registerAggregateFunctionsMaxIntersections(factory);
|
||||||
registerAggregateFunctionHistogram(factory);
|
registerAggregateFunctionHistogram(factory);
|
||||||
|
@ -262,6 +262,13 @@ public:
|
|||||||
|
|
||||||
return it->getMapped();
|
return it->getMapped();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checked read-only lookup: returns the value mapped to `x`
/// and throws LOGICAL_ERROR when the key is absent.
/// The reference is const because the method is const: a read-only handle
/// to the map must not hand out a way to modify its contents.
const typename Cell::Mapped & ALWAYS_INLINE at(const Key & x) const
{
    if (auto it = this->find(x); it != this->end())
        return it->getMapped();
    throw DB::Exception("Cannot find element in HashMap::at method", DB::ErrorCodes::LOGICAL_ERROR);
}
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace std
|
namespace std
|
||||||
|
Loading…
Reference in New Issue
Block a user