diff --git a/src/AggregateFunctions/AggregateFunctionContingencyCoefficient.cpp b/src/AggregateFunctions/AggregateFunctionContingencyCoefficient.cpp index 0391fe3c8ee..4d34c14ede6 100644 --- a/src/AggregateFunctions/AggregateFunctionContingencyCoefficient.cpp +++ b/src/AggregateFunctions/AggregateFunctionContingencyCoefficient.cpp @@ -1,61 +1,53 @@ #include -#include +#include #include #include -#include "registerAggregateFunctions.h" #include +#include -namespace ErrorCodes -{ -extern const int BAD_ARGUMENTS; -} namespace DB { + namespace { - -struct ContingencyData : public AggregateFunctionCramersVData +struct ContingencyData : CrossTabData { - Float64 get_result() const + static const char * getName() { - if (cur_size < 2){ - throw Exception("Aggregate function contingency coefficient requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS); - } + return "contingency"; + } + + Float64 getResult() const + { + if (count < 2) + return std::numeric_limits::quiet_NaN(); + Float64 phi = 0.0; - for (const auto & cell : pairs) { - UInt128 hash_pair = cell.getKey(); - UInt64 count_of_pair_tmp = cell.getMapped(); - Float64 count_of_pair = Float64(count_of_pair_tmp); - UInt64 hash1 = (hash_pair << 64 >> 64); - UInt64 hash2 = (hash_pair >> 64); + for (const auto & [key, value_ab] : count_ab) + { + Float64 value_a = count_a.at(key.items[0]); + Float64 value_b = count_b.at(key.items[1]); - UInt64 count1_tmp = n_i.find(hash1)->getMapped(); - UInt64 count2_tmp = n_j.find(hash2)->getMapped(); - Float64 count1 = static_cast(count1_tmp); - Float64 count2 = Float64(count2_tmp); - - phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size) - - 2 * count_of_pair + (count1 * count2 / cur_size)); + phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count; } - phi /= cur_size; - return sqrt(phi / (phi + cur_size)); + phi /= count; + + return sqrt(phi / (phi + count)); } }; - -AggregateFunctionPtr createAggregateFunctionContingencyCoefficient(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) +void registerAggregateFunctionContingency(AggregateFunctionFactory & factory) { - assertNoParameters(name, parameters); - return std::make_shared>(argument_types); + factory.registerFunction(ContingencyData::getName(), + [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) + { + assertNoParameters(name, parameters); + return std::make_shared>(argument_types); + }); } } -void registerAggregateFunctionContingencyCoefficient(AggregateFunctionFactory & factory) -{ - factory.registerFunction("ContingencyCoefficient", createAggregateFunctionContingencyCoefficient); -} - } diff --git a/src/AggregateFunctions/AggregateFunctionCramersV.cpp b/src/AggregateFunctions/AggregateFunctionCramersV.cpp index b04c6a37056..850070f26b2 100644 --- a/src/AggregateFunctions/AggregateFunctionCramersV.cpp +++ b/src/AggregateFunctions/AggregateFunctionCramersV.cpp @@ -1,27 +1,56 @@ #include -#include +#include #include #include -#include "registerAggregateFunctions.h" #include +#include + namespace DB { + namespace { -AggregateFunctionPtr createAggregateFunctionCramersV(const std::string & name, const DataTypes & argument_types, - const Array & parameters, const Settings *) +struct CramersVData : CrossTabData { - assertNoParameters(name, parameters); - return std::make_shared>(argument_types); -} + static const char * getName() + { + return "cramersV"; + } + + Float64 getResult() const + { + if (count < 2) + return std::numeric_limits::quiet_NaN(); + + Float64 phi = 0.0; + for (const auto & [key, value_ab] : count_ab) + { + Float64 value_a = count_a.at(key.items[0]); + Float64 value_b = count_b.at(key.items[1]); + + phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count; + } + + phi /= count; + UInt64 q = std::min(count_a.size(), count_b.size()); + phi /= q - 1; + + return sqrt(phi); + } +}; } void registerAggregateFunctionCramersV(AggregateFunctionFactory & factory) { - factory.registerFunction("CramersV", createAggregateFunctionCramersV); + factory.registerFunction(CramersVData::getName(), + [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) + { + assertNoParameters(name, parameters); + return std::make_shared>(argument_types); + }); } } diff --git a/src/AggregateFunctions/AggregateFunctionCramersV.h b/src/AggregateFunctions/AggregateFunctionCramersV.h deleted file mode 100644 index 383647f8aa9..00000000000 --- a/src/AggregateFunctions/AggregateFunctionCramersV.h +++ /dev/null @@ -1,207 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - - -namespace ErrorCodes -{ -extern const int BAD_ARGUMENTS; -} - -namespace DB -{ - - - struct AggregateFunctionCramersVData - { - size_t cur_size = 0; - HashMap n_i; - HashMap n_j; - HashMap pairs; - - - void add(UInt64 hash1, UInt64 hash2) - { - cur_size += 1; - n_i[hash1] += 1; - n_j[hash2] += 1; - - UInt128 hash_pair = hash1 | (static_cast(hash2) << 64); - pairs[hash_pair] += 1; - - } - - void merge(const AggregateFunctionCramersVData &other) - { - cur_size += other.cur_size; - - for (const auto& pair : other.n_i) { - UInt64 hash1 = pair.getKey(); - UInt64 count = pair.getMapped(); - n_i[hash1] += count; - } - for (const auto& pair : other.n_j) { - UInt64 hash1 = pair.getKey(); - UInt64 count = pair.getMapped(); - n_j[hash1] += count; - } - for (const auto& pair : other.pairs) { - UInt128 hash1 = pair.getKey(); - UInt64 count = pair.getMapped(); - pairs[hash1] += count; - } - } - - void serialize(WriteBuffer &buf) const - { - writeBinary(cur_size, buf); - n_i.write(buf); - n_j.write(buf); - pairs.write(buf); - } - - void deserialize(ReadBuffer &buf) - { - readBinary(cur_size, buf); - n_i.read(buf); - n_j.read(buf); - pairs.read(buf); - } - - Float64 get_result() const - { - if (cur_size < 2){ - throw Exception("Aggregate function cramer's v requires et least 2 values in columns", ErrorCodes::BAD_ARGUMENTS); - } - Float64 phi = 0.0; - for (const auto & cell : pairs) { - UInt128 hash_pair = cell.getKey(); - UInt64 count_of_pair_tmp = cell.getMapped(); - Float64 count_of_pair = Float64(count_of_pair_tmp); - UInt64 hash1 = (hash_pair << 64 >> 64); - UInt64 hash2 = (hash_pair >> 64); - - UInt64 count1_tmp = n_i.find(hash1)->getMapped(); - UInt64 count2_tmp = n_j.find(hash2)->getMapped(); - Float64 count1 = static_cast(count1_tmp); - Float64 count2 = Float64(count2_tmp); - - phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size) - - 2 * count_of_pair + (count1 * count2 / cur_size)); - } - phi /= cur_size; - - UInt64 q = std::min(n_i.size(), n_j.size()); - phi /= (q - 1); - return sqrt(phi); - - } - }; - - template - - class AggregateFunctionCramersV : public - IAggregateFunctionDataHelper< - Data, - AggregateFunctionCramersV - > - { - - public: - AggregateFunctionCramersV( - const DataTypes & arguments - ): - IAggregateFunctionDataHelper< - Data, - AggregateFunctionCramersV - > ({arguments}, {}) - { - // notice: arguments has been in factory - } - - String getName() const override - { - return "CramersV"; - } - - bool allocatesMemoryInArena() const override { return false; } - - DataTypePtr getReturnType() const override - { - return std::make_shared>(); - } - - void add( - AggregateDataPtr __restrict place, - const IColumn ** columns, - size_t row_num, - Arena * - ) const override - { - UInt64 hash1 = UniqVariadicHash::apply(1, columns, row_num); - UInt64 hash2 = UniqVariadicHash::apply(1, columns + 1, row_num); - - this->data(place).add(hash1, hash2); - } - - void merge( - AggregateDataPtr __restrict place, - ConstAggregateDataPtr rhs, Arena * - ) const override - { - this->data(place).merge(this->data(rhs)); - } - - void serialize( - ConstAggregateDataPtr __restrict place, - WriteBuffer & buf - ) const override - { - this->data(place).serialize(buf); - } - - void deserialize( - AggregateDataPtr __restrict place, - ReadBuffer & buf, Arena * - ) const override - { - this->data(place).deserialize(buf); - } - - void insertResultInto( - AggregateDataPtr __restrict place, - IColumn & to, - Arena * - ) const override - { - Float64 result = this->data(place).get_result(); -// std::cerr << "cur_size" << this->data(place).cur_size << '\n'; -// std::cerr << "n_i size" << this->data(place).n_i.size() << '\n'; -// std::cerr << "n_j size" << this->data(place).n_j.size() << '\n'; -// std::cerr << "pair size " << this->data(place).pairs.size() << '\n'; -// std::cerr << "result " << result << '\n'; - - auto & column = static_cast &>(to); - column.getData().push_back(result); - } - - }; - -} diff --git a/src/AggregateFunctions/AggregateFunctionCramersVBiasCorrected.cpp b/src/AggregateFunctions/AggregateFunctionCramersVBiasCorrected.cpp new file mode 100644 index 00000000000..48a3029c399 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionCramersVBiasCorrected.cpp @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace +{ + +struct CramersVBiasCorrectedData : CrossTabData +{ + static const char * getName() + { + return "cramersVBiasCorrected"; + } + + Float64 getResult() const + { + if (count < 2) + return std::numeric_limits::quiet_NaN(); + + Float64 phi = 0.0; + for (const auto & [key, value_ab] : count_ab) + { + Float64 value_a = count_a.at(key.items[0]); + Float64 value_b = count_b.at(key.items[1]); + + phi += value_ab * value_ab / (value_a * value_b) * count - 2 * value_ab + (value_a * value_b) / count; + } + + phi /= count; + + Float64 res = std::max(0.0, phi - (static_cast(count_a.size()) - 1) * (static_cast(count_b.size()) - 1) / (count - 1)); + Float64 correction_a = count_a.size() - (static_cast(count_a.size()) - 1) * (static_cast(count_a.size()) - 1) / (count - 1); + Float64 correction_b = count_b.size() - (static_cast(count_b.size()) - 1) * (static_cast(count_b.size()) - 1) / (count - 1); + res /= std::min(correction_a, correction_b) - 1; + + return sqrt(res); + } +}; + +} + +void registerAggregateFunctionCramersVBiasCorrected(AggregateFunctionFactory & factory) +{ + factory.registerFunction(CramersVBiasCorrectedData::getName(), + [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) + { + assertNoParameters(name, parameters); + return std::make_shared>(argument_types); + }); +} + +} diff --git a/src/AggregateFunctions/AggregateFunctionCramersVBiasCorrection.cpp b/src/AggregateFunctions/AggregateFunctionCramersVBiasCorrection.cpp deleted file mode 100644 index c58ca8a59da..00000000000 --- a/src/AggregateFunctions/AggregateFunctionCramersVBiasCorrection.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include -#include -#include -#include -#include "registerAggregateFunctions.h" -#include - - -namespace ErrorCodes -{ - extern const int BAD_ARGUMENTS; -} - -namespace DB -{ -namespace -{ - - -struct BiasCorrectionData : public AggregateFunctionCramersVData -{ - Float64 get_result() const - { - if (cur_size < 2){ - throw Exception("Aggregate function cramer's v bias corrected at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS); - } - Float64 phi = 0.0; - for (const auto & cell : pairs) { - UInt128 hash_pair = cell.getKey(); - UInt64 count_of_pair_tmp = cell.getMapped(); - Float64 count_of_pair = Float64(count_of_pair_tmp); - UInt64 hash1 = (hash_pair << 64 >> 64); - UInt64 hash2 = (hash_pair >> 64); - - UInt64 count1_tmp = n_i.find(hash1)->getMapped(); - UInt64 count2_tmp = n_j.find(hash2)->getMapped(); - Float64 count1 = static_cast(count1_tmp); - Float64 count2 = Float64(count2_tmp); - - phi += ((count_of_pair * count_of_pair / (count1 * count2) * cur_size) - - 2 * count_of_pair + (count1 * count2 / cur_size)); - } - phi /= cur_size; - Float64 answ = std::max(0.0, phi - ((static_cast(n_i.size()) - 1) * (static_cast(n_j.size()) - 1) / (cur_size - 1))); - Float64 k = n_i.size() - (static_cast(n_i.size()) - 1) * (static_cast(n_i.size()) - 1) / (cur_size - 1); - Float64 r = n_j.size() - (static_cast(n_j.size()) - 1) * (static_cast(n_j.size()) - 1) / (cur_size - 1); - Float64 q = std::min(k, r); - answ /= (q - 1); - return sqrt(answ); - } -}; - - -AggregateFunctionPtr createAggregateFunctionCramersVBiasCorrection(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) -{ - assertNoParameters(name, parameters); - return std::make_shared>(argument_types); -} - -} - -void registerAggregateFunctionCramersVBiasCorrection(AggregateFunctionFactory & factory) -{ - factory.registerFunction("CramersVBiasCorrection", createAggregateFunctionCramersVBiasCorrection); -} - -} diff --git a/src/AggregateFunctions/AggregateFunctionTheilsU.cpp b/src/AggregateFunctions/AggregateFunctionTheilsU.cpp index b2eeff3d7c9..3868e57ad6b 100644 --- a/src/AggregateFunctions/AggregateFunctionTheilsU.cpp +++ b/src/AggregateFunctions/AggregateFunctionTheilsU.cpp @@ -1,48 +1,43 @@ #include -#include +#include #include #include -#include "registerAggregateFunctions.h" #include +#include -namespace ErrorCodes -{ -extern const int BAD_ARGUMENTS; -} namespace DB { + namespace { - - struct TheilsUData : public AggregateFunctionCramersVData +struct TheilsUData : CrossTabData { - Float64 get_result() const + static const char * getName() { - if (cur_size < 2){ - throw Exception("Aggregate function theil's u requires at least 2 values in columns", ErrorCodes::BAD_ARGUMENTS); - } - Float64 h_x = 0.0; - for (const auto & cell : n_i) { - UInt64 count_x_tmp = cell.getMapped(); - Float64 count_x = Float64(count_x_tmp); - h_x += (count_x / cur_size) * (log(count_x / cur_size)); - } + return "theilsU"; + } + Float64 getResult() const + { + if (count < 2) + return std::numeric_limits::quiet_NaN(); + + Float64 h_a = 0.0; + for (const auto & [key, value] : count_a) + { + Float64 value_float = value; + h_a += (value_float / count) * log(value_float / count); + } Float64 dep = 0.0; - for (const auto & cell : pairs) { - UInt128 hash_pair = cell.getKey(); - UInt64 count_of_pair_tmp = cell.getMapped(); - Float64 count_of_pair = Float64(count_of_pair_tmp); + for (const auto & [key, value] : count_ab) + { + Float64 value_ab = value; + Float64 value_b = count_b.at(key.items[1]); - UInt64 hash2 = (hash_pair >> 64); - - UInt64 count2_tmp = n_j.find(hash2)->getMapped(); - Float64 count2 = Float64 (count2_tmp); - - dep += (count_of_pair / cur_size) * log(count_of_pair / count2); + dep += (value_ab / count) * log(value_ab / value_b); } dep -= h_x; @@ -51,18 +46,16 @@ namespace } }; - -AggregateFunctionPtr createAggregateFunctionTheilsU(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) -{ - assertNoParameters(name, parameters); - return std::make_shared>(argument_types); -} - -} - void registerAggregateFunctionTheilsU(AggregateFunctionFactory & factory) { - factory.registerFunction("TheilsU", createAggregateFunctionTheilsU); + factory.registerFunction(TheilsUData::getName(), + [](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) + { + assertNoParameters(name, parameters); + return std::make_shared>(argument_types); + }); +} + } } diff --git a/src/AggregateFunctions/CrossTab.h b/src/AggregateFunctions/CrossTab.h new file mode 100644 index 00000000000..4fc6dc2c21b --- /dev/null +++ b/src/AggregateFunctions/CrossTab.h @@ -0,0 +1,134 @@ +#pragma once + +#include +#include +#include +#include +#include + + +/** Aggregate function that calculates statistics on top of cross-tab: + * - histogram of every argument and every pair of elements. + * These statistics include: + * - Cramer's V; + * - Theil's U; + * - contingency coefficient; + * It can be interpreted as interdependency coefficient between arguments; + * or non-parametric correlation coefficient. + */ +namespace DB +{ + +struct CrossTabData +{ + /// Total count. + UInt64 count = 0; + + /// Count of every value of the first and second argument (values are pre-hashed). + /// Note: non-cryptographic 64bit hash is used, it means that the calculation is approximate. + HashMapWithStackMemory count_a; + HashMapWithStackMemory count_b; + + /// Count of every pair of values. We pack two hashes into UInt128. + HashMapWithStackMemory count_ab; + + + void add(UInt64 hash1, UInt64 hash2) + { + ++count; + ++count_a[hash1]; + ++count_b[hash2]; + + UInt128 hash_pair{hash1, hash2}; + ++count_ab[hash_pair]; + } + + void merge(const CrossTabData & other) + { + count += other.count; + for (const auto & [key, value] : other.count_a) + count_a[key] += value; + for (const auto & [key, value] : other.count_b) + count_b[key] += value; + for (const auto & [key, value] : other.count_ab) + count_ab[key] += value; + } + + void serialize(WriteBuffer &buf) const + { + writeBinary(count, buf); + count_a.write(buf); + count_b.write(buf); + count_ab.write(buf); + } + + void deserialize(ReadBuffer &buf) + { + readBinary(count, buf); + count_a.read(buf); + count_b.read(buf); + count_ab.read(buf); + } +}; + + +template +class AggregateFunctionCrossTab : public IAggregateFunctionDataHelper> +{ +public: + AggregateFunctionCrossTab(const DataTypes & arguments) + : IAggregateFunctionDataHelper>({arguments}, {}) + { + } + + String getName() const override + { + return Data::getName(); + } + + bool allocatesMemoryInArena() const override + { + return false; + } + + DataTypePtr getReturnType() const override + { + return std::make_shared>(); + } + + void add( + AggregateDataPtr __restrict place, + const IColumn ** columns, + size_t row_num, + Arena *) const override + { + UInt64 hash1 = UniqVariadicHash::apply(1, &columns[0], row_num); + UInt64 hash2 = UniqVariadicHash::apply(1, &columns[1], row_num); + + this->data(place).add(hash1, hash2); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override + { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf) const override + { + this->data(place).serialize(buf); + } + + void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, Arena *) const override + { + this->data(place).deserialize(buf); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override + { + Float64 result = this->data(place).getResult(); + auto & column = static_cast &>(to); + column.getData().push_back(result); + } +}; + +} diff --git a/src/AggregateFunctions/registerAggregateFunctions.cpp b/src/AggregateFunctions/registerAggregateFunctions.cpp index b10f3832e21..c9e46329735 100644 --- a/src/AggregateFunctions/registerAggregateFunctions.cpp +++ b/src/AggregateFunctions/registerAggregateFunctions.cpp @@ -50,8 +50,8 @@ void registerAggregateFunctionWelchTTest(AggregateFunctionFactory &); void registerAggregateFunctionStudentTTest(AggregateFunctionFactory &); void registerAggregateFunctionCramersV(AggregateFunctionFactory &); void registerAggregateFunctionTheilsU(AggregateFunctionFactory &); -void registerAggregateFunctionContingencyCoefficient(AggregateFunctionFactory &); -void registerAggregateFunctionCramersVBiasCorrection(AggregateFunctionFactory &); +void registerAggregateFunctionContingency(AggregateFunctionFactory &); +void registerAggregateFunctionCramersVBiasCorrected(AggregateFunctionFactory &); void registerAggregateFunctionSingleValueOrNull(AggregateFunctionFactory &); void registerAggregateFunctionSequenceNextNode(AggregateFunctionFactory &); void registerAggregateFunctionExponentialMovingAverage(AggregateFunctionFactory &); @@ -105,8 +105,8 @@ void registerAggregateFunctions() registerAggregateFunctionsBitwise(factory); registerAggregateFunctionCramersV(factory); registerAggregateFunctionTheilsU(factory); - registerAggregateFunctionContingencyCoefficient(factory); - registerAggregateFunctionCramersVBiasCorrection(factory); + registerAggregateFunctionContingency(factory); + registerAggregateFunctionCramersVBiasCorrected(factory); registerAggregateFunctionsBitmap(factory); registerAggregateFunctionsMaxIntersections(factory); registerAggregateFunctionHistogram(factory); diff --git a/src/Common/HashTable/HashMap.h b/src/Common/HashTable/HashMap.h index c5675d4d7c9..e619421b8f7 100644 --- a/src/Common/HashTable/HashMap.h +++ b/src/Common/HashTable/HashMap.h @@ -262,6 +262,13 @@ public: return it->getMapped(); } + + typename Cell::Mapped & ALWAYS_INLINE at(const Key & x) const + { + if (auto it = this->find(x); it != this->end()) + return it->getMapped(); + throw DB::Exception("Cannot find element in HashMap::at method", DB::ErrorCodes::LOGICAL_ERROR); + } }; namespace std