diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp b/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp index e577df472c8..60712636562 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp @@ -51,6 +51,10 @@ public: if (!has_nullable_types) throw Exception("Aggregate function combinator 'Null' requires at least one argument to be Nullable", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if (nested_function) + if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params)) + return adapter; + /// Special case for 'count' function. It could be called with Nullable arguments /// - that means - count number of calls, when all arguments are not NULL. if (nested_function && nested_function->getName() == "count") @@ -71,9 +75,9 @@ public: else { if (return_type_is_nullable) - return std::make_shared>(nested_function, arguments, params); + return std::make_shared>(nested_function, arguments, params); else - return std::make_shared>(nested_function, arguments, params); + return std::make_shared>(nested_function, arguments, params); } } }; diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.h b/dbms/src/AggregateFunctions/AggregateFunctionNull.h index a5000f30cd5..a0fe96b6f62 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.h @@ -204,12 +204,12 @@ public: }; -template -class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase> +template +class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase> { public: AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) - : AggregateFunctionNullBase>(std::move(nested_function_), arguments, params), + : AggregateFunctionNullBase>(std::move(nested_function_), arguments, params), number_of_arguments(arguments.size()) { if (number_of_arguments == 1) @@ -233,7 +233,7 @@ public: if (is_nullable[i]) { const ColumnNullable & nullable_col = assert_cast(*columns[i]); - if (nullable_col.isNullAt(row_num)) + if (null_is_skipped && nullable_col.isNullAt(row_num)) { /// If at least one column has a null value in the current row, /// we don't process this row. diff --git a/dbms/src/AggregateFunctions/AggregateFunctionWindowFunnel.h b/dbms/src/AggregateFunctions/AggregateFunctionWindowFunnel.h index e8668b6172e..e19751d8daa 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionWindowFunnel.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionWindowFunnel.h @@ -11,7 +11,7 @@ #include #include -#include +#include namespace DB { @@ -232,6 +232,11 @@ public: return std::make_shared(); } + AggregateFunctionPtr getOwnNullAdapter(const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override + { + return std::make_shared>(nested_function, arguments, params); + } + void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override { bool has_event = false; diff --git a/dbms/src/AggregateFunctions/IAggregateFunction.h b/dbms/src/AggregateFunctions/IAggregateFunction.h index d7ccd4c206a..5cf8d90aaa3 100644 --- a/dbms/src/AggregateFunctions/IAggregateFunction.h +++ b/dbms/src/AggregateFunctions/IAggregateFunction.h @@ -31,6 +31,8 @@ using DataTypes = std::vector; using AggregateDataPtr = char *; using ConstAggregateDataPtr = const char *; +class IAggregateFunction; +using AggregateFunctionPtr = std::shared_ptr; /** Aggregate functions interface. * Instances of classes with this interface do not contain the data itself for aggregation, @@ -149,6 +151,11 @@ public: virtual void addBatchArray( size_t batch_size, AggregateDataPtr * places, size_t place_offset, const IColumn ** columns, const UInt64 * offsets, Arena * arena) const = 0; + virtual AggregateFunctionPtr getOwnNullAdapter(const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const + { + return nullptr; + } + const DataTypes & getArgumentTypes() const { return argument_types; } const Array & getParameters() const { return parameters; } @@ -244,6 +251,4 @@ public: }; -using AggregateFunctionPtr = std::shared_ptr; - } diff --git a/dbms/tests/queries/0_stateless/00632_aggregation_window_funnel.reference b/dbms/tests/queries/0_stateless/00632_aggregation_window_funnel.reference index 1e572be797c..04851cc0e83 100644 --- a/dbms/tests/queries/0_stateless/00632_aggregation_window_funnel.reference +++ b/dbms/tests/queries/0_stateless/00632_aggregation_window_funnel.reference @@ -36,3 +36,23 @@ [4, 2] [5, 2] [6, 1] +[1, 2] +[2, 2] +[3, 0] +[4, 0] +[1, 2] +[2, 1] +[3, 0] +[4, 0] +[1, 0] +[2, 0] +[3, 1] +[4, 0] +[1, 0] +[2, 0] +[3, 1] +[4, 2] +[1, 0] +[2, 0] +[3, 1] +[4, 1] diff --git a/dbms/tests/queries/0_stateless/00632_aggregation_window_funnel.sql b/dbms/tests/queries/0_stateless/00632_aggregation_window_funnel.sql index a8a2f522be9..43ae0197782 100644 --- a/dbms/tests/queries/0_stateless/00632_aggregation_window_funnel.sql +++ b/dbms/tests/queries/0_stateless/00632_aggregation_window_funnel.sql @@ -64,3 +64,16 @@ select user, windowFunnel(86400)(dt, event='a', event='b', event='c') as s from select user, windowFunnel(86400, 'strict_order')(dt, event='a', event='b', event='c') as s from funnel_test_strict_order group by user order by user format JSONCompactEachRow; select user, windowFunnel(86400, 'strict', 'strict_order')(dt, event='a', event='b', event='c') as s from funnel_test_strict_order group by user order by user format JSONCompactEachRow; drop table funnel_test_strict_order; + +drop table if exists funnel_test_non_null; +create table funnel_test_non_null (`dt` DateTime, `u` int, `a` Nullable(String), `b` Nullable(String)) engine = MergeTree() partition by dt order by u; +insert into funnel_test_non_null values (1, 1, 'a1', 'b1') (2, 1, 'a2', 'b2'); +insert into funnel_test_non_null values (1, 2, 'a1', null) (2, 2, 'a2', null); +insert into funnel_test_non_null values (1, 3, null, null); +insert into funnel_test_non_null values (1, 4, null, 'b1') (2, 4, 'a2', null) (3, 4, null, 'b3'); +select u, windowFunnel(86400)(dt, a = 'a1', a = 'a2') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow; +select u, windowFunnel(86400)(dt, a = 'a1', b = 'b2') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow; +select u, windowFunnel(86400)(dt, a is null and b is null) as s from funnel_test_non_null group by u order by u format JSONCompactEachRow; +select u, windowFunnel(86400)(dt, a is null, b = 'b3') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow; +select u, windowFunnel(86400, 'strict_order')(dt, a is null, b = 'b3') as s from funnel_test_non_null group by u order by u format JSONCompactEachRow; +drop table funnel_test_non_null;