diff --git a/src/AggregateFunctions/AggregateFunctionCount.cpp b/src/AggregateFunctions/AggregateFunctionCount.cpp index 6ea63bedaf0..05824947b87 100644 --- a/src/AggregateFunctions/AggregateFunctionCount.cpp +++ b/src/AggregateFunctions/AggregateFunctionCount.cpp @@ -8,7 +8,7 @@ namespace DB { AggregateFunctionPtr AggregateFunctionCount::getOwnNullAdapter( - const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const + const AggregateFunctionPtr &, const DataTypes & types, const Array & params, const AggregateFunctionProperties & /*properties*/) const { return std::make_shared(types[0], params); } diff --git a/src/AggregateFunctions/AggregateFunctionCount.h b/src/AggregateFunctions/AggregateFunctionCount.h index 29c5de0021c..eb1583df92a 100644 --- a/src/AggregateFunctions/AggregateFunctionCount.h +++ b/src/AggregateFunctions/AggregateFunctionCount.h @@ -69,7 +69,7 @@ public: } AggregateFunctionPtr getOwnNullAdapter( - const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const override; + const AggregateFunctionPtr &, const DataTypes & types, const Array & params, const AggregateFunctionProperties & /*properties*/) const override; }; diff --git a/src/AggregateFunctions/AggregateFunctionIf.cpp b/src/AggregateFunctions/AggregateFunctionIf.cpp index 19a175de911..47afddaf7ff 100644 --- a/src/AggregateFunctions/AggregateFunctionIf.cpp +++ b/src/AggregateFunctions/AggregateFunctionIf.cpp @@ -1,6 +1,7 @@ #include #include #include "registerAggregateFunctions.h" +#include "AggregateFunctionNull.h" namespace DB @@ -8,6 +9,7 @@ namespace DB namespace ErrorCodes { + extern const int LOGICAL_ERROR; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } @@ -40,6 +42,164 @@ public: } }; +/** There are two cases: for single argument and variadic. + * Code for single argument is much more efficient. + */ +template +class AggregateFunctionIfNullUnary final + : public AggregateFunctionNullBase> +{ +private: + size_t num_arguments; + + using Base = AggregateFunctionNullBase>; +public: + + String getName() const override + { + return Base::getName() + "If"; + } + + AggregateFunctionIfNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) + : Base(std::move(nested_function_), arguments, params), num_arguments(arguments.size()) + { + if (num_arguments == 0) + throw Exception("Aggregate function " + getName() + " require at least one argument", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + } + + static inline bool singleFilter(const IColumn ** columns, size_t row_num, size_t num_arguments) + { + const IColumn * filter_column = columns[num_arguments - 1]; + if (const ColumnNullable * nullable_column = typeid_cast(filter_column)) + filter_column = nullable_column->getNestedColumnPtr().get(); + + return assert_cast(*filter_column).getData()[row_num]; + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + const ColumnNullable * column = assert_cast(columns[0]); + const IColumn * nested_column = &column->getNestedColumn(); + if (!column->isNullAt(row_num) && singleFilter(columns, row_num, num_arguments)) + { + this->setFlag(place); + this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); + } + } +}; + +template +class AggregateFunctionIfNullVariadic final + : public AggregateFunctionNullBase> +{ +public: + + String getName() const override + { + return Base::getName() + "If"; + } + + AggregateFunctionIfNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) + : Base(std::move(nested_function_), arguments, params), number_of_arguments(arguments.size()) + { + if (number_of_arguments == 1) + throw Exception("Logical error: single argument is passed to AggregateFunctionIfNullVariadic", ErrorCodes::LOGICAL_ERROR); + + if (number_of_arguments > MAX_ARGS) + throw Exception("Maximum number of arguments for aggregate function with Nullable types is " + toString(size_t(MAX_ARGS)), + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + for (size_t i = 0; i < number_of_arguments; ++i) + is_nullable[i] = arguments[i]->isNullable(); + } + + static inline bool singleFilter(const IColumn ** columns, size_t row_num, size_t num_arguments) + { + return assert_cast(*columns[num_arguments - 1]).getData()[row_num]; + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + /// This container stores the columns we really pass to the nested function. + const IColumn * nested_columns[number_of_arguments]; + + for (size_t i = 0; i < number_of_arguments; ++i) + { + if (is_nullable[i]) + { + const ColumnNullable & nullable_col = assert_cast(*columns[i]); + if (null_is_skipped && nullable_col.isNullAt(row_num)) + { + /// If at least one column has a null value in the current row, + /// we don't process this row. + return; + } + nested_columns[i] = &nullable_col.getNestedColumn(); + } + else + nested_columns[i] = columns[i]; + } + + if (singleFilter(nested_columns, row_num, number_of_arguments)) + { + this->setFlag(place); + this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); + } + } + +private: + using Base = AggregateFunctionNullBase>; + + enum { MAX_ARGS = 8 }; + size_t number_of_arguments = 0; + std::array is_nullable; /// Plain array is better than std::vector due to one indirection less. +}; + + +AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter( + const AggregateFunctionPtr & nested_function, const DataTypes & arguments, + const Array & params, const AggregateFunctionProperties & properties) const +{ + bool return_type_is_nullable = !properties.returns_default_when_only_null && getReturnType()->canBeInsideNullable(); + size_t nullable_size = std::count_if(arguments.begin(), arguments.end(), [](const auto & element) { return element->isNullable(); }); + return_type_is_nullable &= nullable_size != 1 || !arguments.back()->isNullable(); /// If only condition is nullable. we should non-nullable type. + bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null; + + if (arguments.size() <= 2 && arguments.front()->isNullable()) + { + if (return_type_is_nullable) + { + return std::make_shared>(nested_func, arguments, params); + } + else + { + if (serialize_flag) + return std::make_shared>(nested_func, arguments, params); + else + return std::make_shared>(nested_func, arguments, params); + } + } + else + { + if (return_type_is_nullable) + { + return std::make_shared>(nested_function, arguments, params); + } + else + { + if (serialize_flag) + return std::make_shared>(nested_function, arguments, params); + else + return std::make_shared>(nested_function, arguments, params); + } + } +} + void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory & factory) { factory.registerCombinator(std::make_shared()); diff --git a/src/AggregateFunctions/AggregateFunctionIf.h b/src/AggregateFunctions/AggregateFunctionIf.h index f04450c9142..d5d2b9be0dd 100644 --- a/src/AggregateFunctions/AggregateFunctionIf.h +++ b/src/AggregateFunctions/AggregateFunctionIf.h @@ -109,6 +109,10 @@ public: { return nested_func->isState(); } + + AggregateFunctionPtr getOwnNullAdapter( + const AggregateFunctionPtr & nested_function, const DataTypes & arguments, + const Array & params, const AggregateFunctionProperties & properties) const override; }; } diff --git a/src/AggregateFunctions/AggregateFunctionNull.cpp b/src/AggregateFunctions/AggregateFunctionNull.cpp index 5e0d6ee6e21..f584ae1f34c 100644 --- a/src/AggregateFunctions/AggregateFunctionNull.cpp +++ b/src/AggregateFunctions/AggregateFunctionNull.cpp @@ -72,7 +72,7 @@ public: assert(nested_function); - if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params)) + if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params, properties)) return adapter; /// If applied to aggregate function with -State combinator, we apply -Null combinator to it's nested_function instead of itself. diff --git a/src/AggregateFunctions/AggregateFunctionWindowFunnel.h b/src/AggregateFunctions/AggregateFunctionWindowFunnel.h index 3297819a9ff..fe45fec4b76 100644 --- a/src/AggregateFunctions/AggregateFunctionWindowFunnel.h +++ b/src/AggregateFunctions/AggregateFunctionWindowFunnel.h @@ -241,7 +241,8 @@ public: } AggregateFunctionPtr getOwnNullAdapter( - const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override + const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params, + const AggregateFunctionProperties & /*properties*/) const override { return std::make_shared>(nested_function, arguments, params); } diff --git a/src/AggregateFunctions/IAggregateFunction.h b/src/AggregateFunctions/IAggregateFunction.h index 4f9552d2345..b5a15eb8cbe 100644 --- a/src/AggregateFunctions/IAggregateFunction.h +++ b/src/AggregateFunctions/IAggregateFunction.h @@ -33,6 +33,7 @@ using ConstAggregateDataPtr = const char *; class IAggregateFunction; using AggregateFunctionPtr = std::shared_ptr; +struct AggregateFunctionProperties; /** Aggregate functions interface. * Instances of classes with this interface do not contain the data itself for aggregation, @@ -185,7 +186,8 @@ public: * arguments and params are for nested_function. */ virtual AggregateFunctionPtr getOwnNullAdapter( - const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const + const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, + const Array & /*params*/, const AggregateFunctionProperties & /*properties*/) const { return nullptr; } diff --git a/tests/queries/0_stateless/01455_nullable_type_with_if_agg_combinator.reference b/tests/queries/0_stateless/01455_nullable_type_with_if_agg_combinator.reference new file mode 100644 index 00000000000..77f38b722ce --- /dev/null +++ b/tests/queries/0_stateless/01455_nullable_type_with_if_agg_combinator.reference @@ -0,0 +1,3 @@ +\N Nullable(UInt8) +\N Nullable(UInt8) +0 UInt8 diff --git a/tests/queries/0_stateless/01455_nullable_type_with_if_agg_combinator.sql b/tests/queries/0_stateless/01455_nullable_type_with_if_agg_combinator.sql new file mode 100644 index 00000000000..852660117f5 --- /dev/null +++ b/tests/queries/0_stateless/01455_nullable_type_with_if_agg_combinator.sql @@ -0,0 +1,6 @@ +-- Value nullable +SELECT anyIf(CAST(number, 'Nullable(UInt8)'), number = 3) AS a, toTypeName(a) FROM numbers(2); +-- Value and condition nullable +SELECT anyIf(number, number = 3) AS a, toTypeName(a) FROM (SELECT CAST(number, 'Nullable(UInt8)') AS number FROM numbers(2)); +-- Condition nullable +SELECT anyIf(CAST(number, 'UInt8'), number = 3) AS a, toTypeName(a) FROM (SELECT CAST(number, 'Nullable(UInt8)') AS number FROM numbers(2));