// Git patch (extracted text). Line breaks were collapsed and the contents of
// angle brackets (#include targets, template arguments, cast target types) are
// missing from this copy, so it documents the change but cannot be applied as-is.
//
// Files touched:
//   * src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp
//       - AggregateFunctionCategoricalIV is now defined here instead of in the
//         header, and is no longer a class template: counters are a fixed
//         `Counter = UInt64`.
//       - The aggregation state is a flat array of (category_count + 1) * 2
//         counters; `counter(place, i, what)` addresses cell i * 2 + (what ? 1 : 0).
//         Slot category_count is incremented once for every row (the totals).
//       - add(): reads y from columns[category_count]; for every category i whose
//         x value is true it increments counter(i, y), then increments the total
//         counter(category_count, y).
//       - merge(): element-wise sum of both cells of every slot, totals included.
//       - serialize()/deserialize(): write/read sizeOfData() raw bytes of the state.
//         NOTE(review): deserialize() calls buf.read(); whether a short read is
//         detected depends on ReadBuffer semantics not shown here - confirm.
//       - insertResultInto(): per category appends
//           (no / sum_no - yes / sum_yes) * log((no / sum_no) / (yes / sum_yes))
//         using Float64 locals; the removed version used precomputed reciprocals
//         and a difference of two logs. sum_no / sum_yes are not checked for zero.
//         NOTE(review): the `[nan]` rows in the new .reference file presumably
//         cover those cases - confirm.
//       - Registration now passes AggregateFunctionProperties with
//         returns_default_when_only_null = true.
//   * src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.h (deleted)
//   * src/DataTypes/Serializations/SerializationArray.cpp
//       - One line replaced by a textually identical one in this copy;
//         presumably a whitespace-only change - confirm against the real patch.
//   * tests/queries/0_stateless/02425_categorical_information_value_properties.{reference,sql} (new)
diff --git a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp index 35654c08659..99ffc87e076 100644 --- a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp +++ b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.cpp @@ -1,12 +1,18 @@ -#include - +#include #include #include #include +#include +#include +#include +#include +#include +#include namespace DB { + struct Settings; namespace ErrorCodes @@ -15,6 +21,136 @@ namespace ErrorCodes extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } +/** The function takes arguments x1, x2, ... xn, y. All arguments are bool. + * x arguments represents the fact that some category is true. + * + * It calculates how many times y was true and how many times y was false when every n-th category was true + * and the total number of times y was true and false. + * + * So, the size of the state is (n + 1) * 2 cells. + */ +class AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper +{ +private: + using Counter = UInt64; + size_t category_count; + + Counter & counter(AggregateDataPtr __restrict place, size_t i, bool what) const + { + return reinterpret_cast(place)[i * 2 + (what ? 1 : 0)]; + } + + const Counter & counter(ConstAggregateDataPtr __restrict place, size_t i, bool what) const + { + return reinterpret_cast(place)[i * 2 + (what ? 
1 : 0)]; + } + +public: + AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) : + IAggregateFunctionHelper{arguments_, params_}, + category_count{arguments_.size() - 1} + { + // notice: argument types has been checked before + } + + String getName() const override + { + return "categoricalInformationValue"; + } + + bool allocatesMemoryInArena() const override { return false; } + + void create(AggregateDataPtr __restrict place) const override + { + memset(place, 0, sizeOfData()); + } + + void destroy(AggregateDataPtr __restrict) const noexcept override + { + // nothing + } + + bool hasTrivialDestructor() const override + { + return true; + } + + size_t sizeOfData() const override + { + return sizeof(Counter) * (category_count + 1) * 2; + } + + size_t alignOfData() const override + { + return alignof(Counter); + } + + void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override + { + const auto * y_col = static_cast(columns[category_count]); + bool y = y_col->getData()[row_num]; + + for (size_t i = 0; i < category_count; ++i) + { + const auto * x_col = static_cast(columns[i]); + bool x = x_col->getData()[row_num]; + + if (x) + ++counter(place, i, y); + } + + ++counter(place, category_count, y); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override + { + for (size_t i = 0; i <= category_count; ++i) + { + counter(place, i, false) += counter(rhs, i, false); + counter(place, i, true) += counter(rhs, i, true); + } + } + + void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override + { + buf.write(place, sizeOfData()); + } + + void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena *) const override + { + buf.read(place, sizeOfData()); + } + + DataTypePtr getReturnType() const override + { + return std::make_shared( + 
std::make_shared>()); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override /// NOLINT + { + auto & col = static_cast(to); + auto & data_col = static_cast(col.getData()); + auto & offset_col = static_cast(col.getOffsetsColumn()); + + data_col.reserve(data_col.size() + category_count); + + Float64 sum_no = static_cast(counter(place, category_count, false)); + Float64 sum_yes = static_cast(counter(place, category_count, true)); + + for (size_t i = 0; i < category_count; ++i) + { + Float64 no = static_cast(counter(place, i, false)); + Float64 yes = static_cast(counter(place, i, true)); + + data_col.insertValue((no / sum_no - yes / sum_yes) * (log((no / sum_no) / (yes / sum_yes)))); + } + + offset_col.insertValue(data_col.size()); + } +}; + + namespace { @@ -39,16 +175,15 @@ AggregateFunctionPtr createAggregateFunctionCategoricalIV( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } - return std::make_shared>(arguments, params); + return std::make_shared(arguments, params); } } -void registerAggregateFunctionCategoricalIV( - AggregateFunctionFactory & factory -) +void registerAggregateFunctionCategoricalIV(AggregateFunctionFactory & factory) { - factory.registerFunction("categoricalInformationValue", createAggregateFunctionCategoricalIV); + AggregateFunctionProperties properties = { .returns_default_when_only_null = true }; + factory.registerFunction("categoricalInformationValue", { createAggregateFunctionCategoricalIV, properties }); } } diff --git a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.h b/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.h deleted file mode 100644 index 0e0db27cf22..00000000000 --- a/src/AggregateFunctions/AggregateFunctionCategoricalInformationValue.h +++ /dev/null @@ -1,135 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - - -namespace DB -{ -struct Settings; - -template -class 
AggregateFunctionCategoricalIV final : public IAggregateFunctionHelper> -{ -private: - size_t category_count; - -public: - AggregateFunctionCategoricalIV(const DataTypes & arguments_, const Array & params_) : - IAggregateFunctionHelper> {arguments_, params_}, - category_count {arguments_.size() - 1} - { - // notice: argument types has been checked before - } - - String getName() const override - { - return "categoricalInformationValue"; - } - - bool allocatesMemoryInArena() const override { return false; } - - void create(AggregateDataPtr __restrict place) const override - { - memset(place, 0, sizeOfData()); - } - - void destroy(AggregateDataPtr __restrict) const noexcept override - { - // nothing - } - - bool hasTrivialDestructor() const override - { - return true; - } - - size_t sizeOfData() const override - { - return sizeof(T) * (category_count + 1) * 2; - } - - size_t alignOfData() const override - { - return alignof(T); - } - - void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override - { - const auto * y_col = static_cast(columns[category_count]); - bool y = y_col->getData()[row_num]; - - for (size_t i : collections::range(0, category_count)) - { - const auto * x_col = static_cast(columns[i]); - bool x = x_col->getData()[row_num]; - - if (x) - reinterpret_cast(place)[i * 2 + size_t(y)] += 1; - } - - reinterpret_cast(place)[category_count * 2 + size_t(y)] += 1; - } - - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override - { - for (size_t i : collections::range(0, category_count + 1)) - { - reinterpret_cast(place)[i * 2] += reinterpret_cast(rhs)[i * 2]; - reinterpret_cast(place)[i * 2 + 1] += reinterpret_cast(rhs)[i * 2 + 1]; - } - } - - void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override - { - buf.write(place, sizeOfData()); - } - - void deserialize(AggregateDataPtr __restrict place, ReadBuffer & 
buf, std::optional /* version */, Arena *) const override - { - buf.read(place, sizeOfData()); - } - - DataTypePtr getReturnType() const override - { - return std::make_shared( - std::make_shared>() - ); - } - - void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override /// NOLINT - { - auto & col = static_cast(to); - auto & data_col = static_cast(col.getData()); - auto & offset_col = static_cast( - col.getOffsetsColumn() - ); - - data_col.reserve(data_col.size() + category_count); - - T sum_no = reinterpret_cast(place)[category_count * 2]; - T sum_yes = reinterpret_cast(place)[category_count * 2 + 1]; - - Float64 rev_no = 1. / sum_no; - Float64 rev_yes = 1. / sum_yes; - - for (size_t i : collections::range(0, category_count)) - { - T no = reinterpret_cast(place)[i * 2]; - T yes = reinterpret_cast(place)[i * 2 + 1]; - - data_col.insertValue((no * rev_no - yes * rev_yes) * (log(no * rev_no) - log(yes * rev_yes))); - } - - offset_col.insertValue(data_col.size()); - } -}; - -} diff --git a/src/DataTypes/Serializations/SerializationArray.cpp b/src/DataTypes/Serializations/SerializationArray.cpp index abd99038e98..75c4c01ac9a 100644 --- a/src/DataTypes/Serializations/SerializationArray.cpp +++ b/src/DataTypes/Serializations/SerializationArray.cpp @@ -174,7 +174,7 @@ namespace { auto current_offset = offsets_data[i]; sizes_data[i] = current_offset - prev_offset; - prev_offset = current_offset; + prev_offset = current_offset; } return column_sizes; diff --git a/tests/queries/0_stateless/02425_categorical_information_value_properties.reference b/tests/queries/0_stateless/02425_categorical_information_value_properties.reference new file mode 100644 index 00000000000..bc3af98b060 --- /dev/null +++ b/tests/queries/0_stateless/02425_categorical_information_value_properties.reference @@ -0,0 +1,12 @@ +0.347 +0.5 +0.347 +0.347 +[nan] +[nan] +[nan] +[nan] +[0] +\N +[nan] +[0,0] diff --git 
a/tests/queries/0_stateless/02425_categorical_information_value_properties.sql b/tests/queries/0_stateless/02425_categorical_information_value_properties.sql new file mode 100644 index 00000000000..81ed8400680 --- /dev/null +++ b/tests/queries/0_stateless/02425_categorical_information_value_properties.sql @@ -0,0 +1,14 @@ +SELECT round(arrayJoin(categoricalInformationValue(x.1, x.2)), 3) FROM (SELECT arrayJoin([(0, 0), (NULL, 2), (1, 0), (1, 1)]) AS x); +SELECT corr(c1, c2) FROM VALUES((0, 0), (NULL, 2), (1, 0), (1, 1)); +SELECT round(arrayJoin(categoricalInformationValue(c1, c2)), 3) FROM VALUES((0, 0), (NULL, 2), (1, 0), (1, 1)); +SELECT round(arrayJoin(categoricalInformationValue(c1, c2)), 3) FROM VALUES((0, 0), (NULL, 1), (1, 0), (1, 1)); +SELECT categoricalInformationValue(c1, c2) FROM VALUES((0, 0), (NULL, 1)); +SELECT categoricalInformationValue(c1, c2) FROM VALUES((NULL, 1)); -- { serverError 43 } +SELECT categoricalInformationValue(dummy, dummy); +SELECT categoricalInformationValue(dummy, dummy) WHERE 0; +SELECT categoricalInformationValue(c1, c2) FROM VALUES((toNullable(0), 0)); +SELECT groupUniqArray(*) FROM VALUES(toNullable(0)); +SELECT groupUniqArray(*) FROM VALUES(NULL); +SELECT categoricalInformationValue(c1, c2) FROM VALUES((NULL, NULL)); -- { serverError 43 } +SELECT categoricalInformationValue(c1, c2) FROM VALUES((0, 0), (NULL, 0)); +SELECT quantiles(0.5, 0.9)(c1) FROM VALUES(0::Nullable(UInt8));