From e2d434ea18c771ced694adad2fc6cf3d0241adc5 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sat, 17 Sep 2022 23:37:57 +0200 Subject: [PATCH 1/4] Start with code removal --- ...ateFunctionCategoricalInformationValue.cpp | 130 ++++++++++++++++- ...egateFunctionCategoricalInformationValue.h | 135 ------------------ .../Serializations/SerializationArray.cpp | 2 +- 3 files changed, 126 insertions(+), 141 deletions(-) delete mode 100644 src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.h diff --git a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp index 35654c08659..6fe0de34686 100644 --- a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp +++ b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp @@ -1,12 +1,18 @@ -#include - +#include #include #include #include +#include +#include +#include +#include +#include +#include namespace DB { + struct Settings; namespace ErrorCodes @@ -15,6 +21,122 @@ namespace ErrorCodes extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } + +template +class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper> +{ +private: + size_t category_count; + +public: + AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) : + IAggregateFunctionHelper>{arguments_, params_}, + category_count{arguments_.size() - 1} + { + // notice: argument types has been checked before + } + + String getName() const override + { + return "categoricalInformationValue"; + } + + bool allocatesMemoryInArena() const override { return false; } + + void create(AggregateDataPtr __restrict place) const override + { + memset(place, 0, sizeOfData()); + } + + void destroy(AggregateDataPtr __restrict) const noexcept override + { + // nothing + } + + bool hasTrivialDestructor() const override + { + return true; + } + + size_t sizeOfData() const override + { + return 
sizeof(T) * (category_count + 1) * 2; + } + + size_t alignOfData() const override + { + return alignof(T); + } + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + const auto * y_col = static_cast(columns[category_count]); + bool y = y_col->getData()[row_num]; + + for (size_t i = 0; i < category_count; ++i) + { + const auto * x_col = static_cast(columns[i]); + bool x = x_col->getData()[row_num]; + + if (x) + reinterpret_cast(place)[i * 2 + size_t(y)] += 1; + } + + reinterpret_cast(place)[category_count * 2 + size_t(y)] += 1; + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override + { + for (size_t i = 0; i <= category_count; ++i) + { + reinterpret_cast(place)[i * 2] += reinterpret_cast(rhs)[i * 2]; + reinterpret_cast(place)[i * 2 + 1] += reinterpret_cast(rhs)[i * 2 + 1]; + } + } + + void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override + { + buf.write(place, sizeOfData()); + } + + void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena *) const override + { + buf.read(place, sizeOfData()); + } + + DataTypePtr getReturnType() const override + { + return std::make_shared( + std::make_shared>()); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override /// NOLINT + { + auto & col = static_cast(to); + auto & data_col = static_cast(col.getData()); + auto & offset_col = static_cast(col.getOffsetsColumn()); + + data_col.reserve(data_col.size() + category_count); + + T sum_no = reinterpret_cast(place)[category_count * 2]; + T sum_yes = reinterpret_cast(place)[category_count * 2 + 1]; + + Float64 rev_no = 1. / sum_no; + Float64 rev_yes = 1. 
/ sum_yes; + + for (size_t i = 0; i < category_count; ++i) + { + T no = reinterpret_cast(place)[i * 2]; + T yes = reinterpret_cast(place)[i * 2 + 1]; + + data_col.insertValue((no * rev_no - yes * rev_yes) * (log(no * rev_no) - log(yes * rev_yes))); + } + + offset_col.insertValue(data_col.size()); + } +}; + + namespace { @@ -44,9 +166,7 @@ AggregateFunctionPtr createAggregateFunctionCategoricalIV( } -void registerAggregateFunctionCategoricalIV( - AggregateFunctionFactory & factory -) +void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory & factory) { factory.registerFunction("categoricalInformationValue", createAggregateFunctionCategoricalIV); } diff --git a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.h b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.h deleted file mode 100644 index 0e0db27cf22..00000000000 --- a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.h +++ /dev/null @@ -1,135 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - - -namespace DB -{ -struct Settings; - -template -class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper> -{ -private: - size_t category_count; - -public: - AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) : - IAggregateFunctionHelper> {arguments_, params_}, - category_count {arguments_.size() - 1} - { - // notice: argument types has been checked before - } - - String getName() const override - { - return "categoricalInformationValue"; - } - - bool allocatesMemoryInArena() const override { return false; } - - void create(AggregateDataPtr __restrict place) const override - { - memset(place, 0, sizeOfData()); - } - - void destroy(AggregateDataPtr __restrict) const noexcept override - { - // nothing - } - - bool hasTrivialDestructor() const override - { - return true; - } - - size_t sizeOfData() const override - { - return sizeof(T) * 
(category_count + 1) * 2; - } - - size_t alignOfData() const override - { - return alignof(T); - } - - void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override - { - const auto * y_col = static_cast(columns[category_count]); - bool y = y_col->getData()[row_num]; - - for (size_t i : collections::range(0, category_count)) - { - const auto * x_col = static_cast(columns[i]); - bool x = x_col->getData()[row_num]; - - if (x) - reinterpret_cast(place)[i * 2 + size_t(y)] += 1; - } - - reinterpret_cast(place)[category_count * 2 + size_t(y)] += 1; - } - - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override - { - for (size_t i : collections::range(0, category_count + 1)) - { - reinterpret_cast(place)[i * 2] += reinterpret_cast(rhs)[i * 2]; - reinterpret_cast(place)[i * 2 + 1] += reinterpret_cast(rhs)[i * 2 + 1]; - } - } - - void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override - { - buf.write(place, sizeOfData()); - } - - void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena *) const override - { - buf.read(place, sizeOfData()); - } - - DataTypePtr getReturnType() const override - { - return std::make_shared( - std::make_shared>() - ); - } - - void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override /// NOLINT - { - auto & col = static_cast(to); - auto & data_col = static_cast(col.getData()); - auto & offset_col = static_cast( - col.getOffsetsColumn() - ); - - data_col.reserve(data_col.size() + category_count); - - T sum_no = reinterpret_cast(place)[category_count * 2]; - T sum_yes = reinterpret_cast(place)[category_count * 2 + 1]; - - Float64 rev_no = 1. / sum_no; - Float64 rev_yes = 1. 
/ sum_yes; - - for (size_t i : collections::range(0, category_count)) - { - T no = reinterpret_cast(place)[i * 2]; - T yes = reinterpret_cast(place)[i * 2 + 1]; - - data_col.insertValue((no * rev_no - yes * rev_yes) * (log(no * rev_no) - log(yes * rev_yes))); - } - - offset_col.insertValue(data_col.size()); - } -}; - -} diff --git a/src/DataTypes/Serializations/SerializationArray.cpp b/src/DataTypes/Serializations/SerializationArray.cpp index abd99038e98..75c4c01ac9a 100644 --- a/src/DataTypes/Serializations/SerializationArray.cpp +++ b/src/DataTypes/Serializations/SerializationArray.cpp @@ -174,7 +174,7 @@ namespace { auto current_offset = offsets_data[i]; sizes_data[i] = current_offset - prev_offset; - prev_offset = current_offset; + prev_offset = current_offset; } return column_sizes; From c424ad12aab6e29dbd187ace319c2401504850bb Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 18 Sep 2022 00:09:21 +0200 Subject: [PATCH 2/4] Simplification --- ...ateFunctionCategoricalInformationValue.cpp | 52 ++++++++++++------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp index 6fe0de34686..fb03abe03fe 100644 --- a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp +++ b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp @@ -21,16 +21,33 @@ namespace ErrorCodes extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } - -template -class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper> +/** The function takes arguments x1, x2, ... xn, y. All arguments are bool. + * Each x argument represents the fact that the corresponding category is true. + * + * For each category, it counts how many times y was true and how many times y was false when that category was true, + * as well as the total number of times y was true and false. + * + * So, the size of the state is (n + 1) * 2 cells. 
+ */ +class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper { private: + using Counter = UInt64; size_t category_count; + Counter & counter(AggregateDataPtr __restrict place, size_t i, bool what) const + { + return reinterpret_cast(place)[i * 2 + (what ? 1 : 0)]; + } + + const Counter & counter(ConstAggregateDataPtr __restrict place, size_t i, bool what) const + { + return reinterpret_cast(place)[i * 2 + (what ? 1 : 0)]; + } + public: AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) : - IAggregateFunctionHelper>{arguments_, params_}, + IAggregateFunctionHelper{arguments_, params_}, category_count{arguments_.size() - 1} { // notice: argument types has been checked before @@ -60,12 +77,12 @@ public: size_t sizeOfData() const override { - return sizeof(T) * (category_count + 1) * 2; + return sizeof(Counter) * (category_count + 1) * 2; } size_t alignOfData() const override { - return alignof(T); + return alignof(Counter); } void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override @@ -79,18 +96,18 @@ public: bool x = x_col->getData()[row_num]; if (x) - reinterpret_cast(place)[i * 2 + size_t(y)] += 1; + ++counter(place, i, y); } - reinterpret_cast(place)[category_count * 2 + size_t(y)] += 1; + ++counter(place, category_count, y); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override { for (size_t i = 0; i <= category_count; ++i) { - reinterpret_cast(place)[i * 2] += reinterpret_cast(rhs)[i * 2]; - reinterpret_cast(place)[i * 2 + 1] += reinterpret_cast(rhs)[i * 2 + 1]; + counter(place, i, false) += counter(rhs, i, false); + counter(place, i, true) += counter(rhs, i, true); } } @@ -118,18 +135,15 @@ public: data_col.reserve(data_col.size() + category_count); - T sum_no = reinterpret_cast(place)[category_count * 2]; - T sum_yes = reinterpret_cast(place)[category_count * 2 + 1]; - - Float64 rev_no = 1. 
/ sum_no; - Float64 rev_yes = 1. / sum_yes; + Float64 sum_no = static_cast(counter(place, category_count, false)); + Float64 sum_yes = static_cast(counter(place, category_count, true)); for (size_t i = 0; i < category_count; ++i) { - T no = reinterpret_cast(place)[i * 2]; - T yes = reinterpret_cast(place)[i * 2 + 1]; + Float64 no = static_cast(counter(place, i, false)); + Float64 yes = static_cast(counter(place, i, true)); - data_col.insertValue((no * rev_no - yes * rev_yes) * (log(no * rev_no) - log(yes * rev_yes))); + data_col.insertValue((no / sum_no - yes / sum_yes) * (log((no / sum_no) / (yes / sum_yes)))); } offset_col.insertValue(data_col.size()); @@ -161,7 +175,7 @@ AggregateFunctionPtr createAggregateFunctionCategoricalIV( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } - return std::make_shared>(arguments, params); + return std::make_shared(arguments, params); } } From 6f1878b12adca9e4336be69c7a7f4a0e83a89de2 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 18 Sep 2022 01:17:22 +0200 Subject: [PATCH 3/4] Fix error --- .../AggregateFunctionCategoricalInformationValue.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp index fb03abe03fe..99ffc87e076 100644 --- a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp +++ b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp @@ -182,7 +182,8 @@ AggregateFunctionPtr createAggregateFunctionCategoricalIV( void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory & factory) { - factory.registerFunction("categoricalInformationValue", createAggregateFunctionCategoricalIV); + AggregateFunctionProperties properties = { .returns_default_when_only_null = true }; + factory.registerFunction("categoricalInformationValue", { createAggregateFunctionCategoricalIV, properties }); } } From 
416b5c701b211111693cbdd3ea04d13889c1ca96 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 18 Sep 2022 01:21:33 +0200 Subject: [PATCH 4/4] Add a test --- ...egorical_information_value_properties.reference | 12 ++++++++++++ ...25_categorical_information_value_properties.sql | 14 ++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 tests/queries/0_stateless/02425_categorical_information_value_properties.reference create mode 100644 tests/queries/0_stateless/02425_categorical_information_value_properties.sql diff --git a/tests/queries/0_stateless/02425_categorical_information_value_properties.reference b/tests/queries/0_stateless/02425_categorical_information_value_properties.reference new file mode 100644 index 00000000000..bc3af98b060 --- /dev/null +++ b/tests/queries/0_stateless/02425_categorical_information_value_properties.reference @@ -0,0 +1,12 @@ +0.347 +0.5 +0.347 +0.347 +[nan] +[nan] +[nan] +[nan] +[0] +\N +[nan] +[0,0] diff --git a/tests/queries/0_stateless/02425_categorical_information_value_properties.sql b/tests/queries/0_stateless/02425_categorical_information_value_properties.sql new file mode 100644 index 00000000000..81ed8400680 --- /dev/null +++ b/tests/queries/0_stateless/02425_categorical_information_value_properties.sql @@ -0,0 +1,14 @@ +SELECT round(arrayJoin(categoricalInformationValue(x.1, x.2)), 3) FROM (SELECT arrayJoin([(0, 0), (NULL, 2), (1, 0), (1, 1)]) AS x); +SELECT corr(c1, c2) FROM VALUES((0, 0), (NULL, 2), (1, 0), (1, 1)); +SELECT round(arrayJoin(categoricalInformationValue(c1, c2)), 3) FROM VALUES((0, 0), (NULL, 2), (1, 0), (1, 1)); +SELECT round(arrayJoin(categoricalInformationValue(c1, c2)), 3) FROM VALUES((0, 0), (NULL, 1), (1, 0), (1, 1)); +SELECT categoricalInformationValue(c1, c2) FROM VALUES((0, 0), (NULL, 1)); +SELECT categoricalInformationValue(c1, c2) FROM VALUES((NULL, 1)); -- { serverError 43 } +SELECT categoricalInformationValue(dummy, dummy); +SELECT categoricalInformationValue(dummy, dummy) 
WHERE 0; +SELECT categoricalInformationValue(c1, c2) FROM VALUES((toNullable(0), 0)); +SELECT groupUniqArray(*) FROM VALUES(toNullable(0)); +SELECT groupUniqArray(*) FROM VALUES(NULL); +SELECT categoricalInformationValue(c1, c2) FROM VALUES((NULL, NULL)); -- { serverError 43 } +SELECT categoricalInformationValue(c1, c2) FROM VALUES((0, 0), (NULL, 0)); +SELECT quantiles(0.5, 0.9)(c1) FROM VALUES(0::Nullable(UInt8));