diff --git a/src/AggregateFunctions/AggregateFunctionForEach.h b/src/AggregateFunctions/AggregateFunctionForEach.h index 19f2994d3f1..ee4a168cceb 100644 --- a/src/AggregateFunctions/AggregateFunctionForEach.h +++ b/src/AggregateFunctions/AggregateFunctionForEach.h @@ -247,6 +247,11 @@ public: { return true; } + + bool isState() const override + { + return nested_func->isState(); + } }; diff --git a/src/AggregateFunctions/AggregateFunctionNull.cpp b/src/AggregateFunctions/AggregateFunctionNull.cpp index b8fbad53350..b65b4aba447 100644 --- a/src/AggregateFunctions/AggregateFunctionNull.cpp +++ b/src/AggregateFunctions/AggregateFunctionNull.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include "registerAggregateFunctions.h" @@ -71,6 +72,19 @@ public: if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params)) return adapter; + /// If applied to aggregate function with -State combinator, we apply -Null combinator to its nested_function instead of itself. + /// Because Nullable AggregateFunctionState does not make sense and ruins the logic of managing aggregate function states. 
+ + if (const AggregateFunctionState * function_state = typeid_cast<const AggregateFunctionState *>(nested_function.get())) + { + auto transformed_nested_function = transformAggregateFunction(function_state->getNestedFunction(), properties, arguments, params); + + return std::make_shared<AggregateFunctionState>( + transformed_nested_function, + transformed_nested_function->getArgumentTypes(), + transformed_nested_function->getParameters()); + } + bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getReturnType()->canBeInsideNullable(); bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null; diff --git a/src/AggregateFunctions/IAggregateFunction.h b/src/AggregateFunctions/IAggregateFunction.h index eb9c560af98..25d8580a923 100644 --- a/src/AggregateFunctions/IAggregateFunction.h +++ b/src/AggregateFunctions/IAggregateFunction.h @@ -122,8 +122,9 @@ public: throw Exception("Method predictValues is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED); } - /** Returns true for aggregate functions of type -State. + /** Returns true for aggregate functions of type -State * They are executed as other aggregate functions, but not finalized (return an aggregation state that can be combined with another). + * Also returns true when the final value of this aggregate function contains the State of another aggregate function inside. 
*/ virtual bool isState() const { return false; } diff --git a/src/Columns/ColumnAggregateFunction.cpp b/src/Columns/ColumnAggregateFunction.cpp index 915dd7530c4..4b9dcc8d04e 100644 --- a/src/Columns/ColumnAggregateFunction.cpp +++ b/src/Columns/ColumnAggregateFunction.cpp @@ -85,6 +85,20 @@ void ColumnAggregateFunction::addArena(ConstArenaPtr arena_) foreign_arenas.push_back(arena_); } +namespace +{ + +ConstArenas concatArenas(const ConstArenas & array, ConstArenaPtr arena) +{ + ConstArenas result = array; + if (arena) + result.push_back(std::move(arena)); + + return result; +} + +} + MutableColumnPtr ColumnAggregateFunction::convertToValues(MutableColumnPtr column) { /** If the aggregate function returns an unfinalized/unfinished state, @@ -121,19 +135,27 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues(MutableColumnPtr colum auto & func = column_aggregate_func.func; auto & data = column_aggregate_func.data; - if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get())) - { - auto res = column_aggregate_func.createView(); - res->set(function_state->getNestedFunction()); - res->data.assign(data.begin(), data.end()); - return res; - } - + /// insertResultInto may invalidate states, so we must unshare ownership of them column_aggregate_func.ensureOwnership(); MutableColumnPtr res = func->getReturnType()->createColumn(); res->reserve(data.size()); + /// If there are references to states in final column, we must hold their ownership + /// by holding arenas and source. 
+ + auto callback = [&](auto & subcolumn) + { + if (auto * aggregate_subcolumn = typeid_cast<ColumnAggregateFunction *>(subcolumn.get())) + { + aggregate_subcolumn->foreign_arenas = concatArenas(column_aggregate_func.foreign_arenas, column_aggregate_func.my_arena); + aggregate_subcolumn->src = column_aggregate_func.getPtr(); + } + }; + + callback(res); + res->forEachSubcolumn(callback); + for (auto * val : data) func->insertResultInto(val, *res, &column_aggregate_func.createOrGetArena()); @@ -629,20 +651,6 @@ void ColumnAggregateFunction::getExtremes(Field & min, Field & max) const max = serialized; } -namespace -{ - -ConstArenas concatArenas(const ConstArenas & array, ConstArenaPtr arena) -{ - ConstArenas result = array; - if (arena) - result.push_back(std::move(arena)); - - return result; -} - -} - ColumnAggregateFunction::MutablePtr ColumnAggregateFunction::createView() const { auto res = create(func, concatArenas(foreign_arenas, my_arena)); diff --git a/src/Functions/FunctionsBitmap.cpp b/src/Functions/FunctionsBitmap.cpp index c94566b04b0..72652288872 100644 --- a/src/Functions/FunctionsBitmap.cpp +++ b/src/Functions/FunctionsBitmap.cpp @@ -1,7 +1,6 @@ #include -// TODO include this last because of a broken roaring header. See the comment -// inside. +// TODO include this last because of a broken roaring header. See the comment inside. 
#include diff --git a/tests/queries/0_stateless/01380_nullable_state.reference b/tests/queries/0_stateless/01380_nullable_state.reference new file mode 100644 index 00000000000..f87ff0a3f1f --- /dev/null +++ b/tests/queries/0_stateless/01380_nullable_state.reference @@ -0,0 +1,64 @@ +0100012CCBC234 + +0100012CCBC234 +--- +0100012CCBC234 + +0100012CCBC234 +--- +0100012CCBC234 + +0100012CCBC234 +--- +0100012CCBC234 + +0100012CCBC234 +--- +0100012CCBC234 + +0100012CCBC234 +--- +0100012CCBC234 + +0100012CCBC234 +--- +1 + +1 +--- +0 1 +1 1 +2 1 +3 1 +4 1 + +0 1 +--- +0 1 +1 1 +2 1 +3 1 +4 1 + +0 1 +--- +0 [0] +1 [0] +2 [0] +3 [0] +4 [0] + +0 [0] +--- +0 [0] +1 [0] +2 [0] +3 [0] +4 [0] + +\N [0] +--- +0100012CCBC234 +--- +0100012CCBC234 +--- diff --git a/tests/queries/0_stateless/01380_nullable_state.sql b/tests/queries/0_stateless/01380_nullable_state.sql new file mode 100644 index 00000000000..6841a6ce636 --- /dev/null +++ b/tests/queries/0_stateless/01380_nullable_state.sql @@ -0,0 +1,26 @@ +SELECT hex(toString(uniqState(toNullable(1)))) WITH TOTALS; +SELECT '---'; +SELECT hex(toString(uniqState(x))) FROM (SELECT toNullable(1) AS x) WITH TOTALS; +SELECT '---'; +SELECT DISTINCT hex(toString(uniqState(x))) FROM (SELECT materialize(1) AS k, toNullable(1) AS x FROM numbers(1)) GROUP BY k WITH TOTALS ORDER BY k; +SELECT '---'; +SELECT DISTINCT hex(toString(uniqState(x))) FROM (SELECT materialize(1) AS k, toNullable(1) AS x FROM numbers(10)) GROUP BY k WITH TOTALS ORDER BY k; +SELECT '---'; +SELECT DISTINCT hex(toString(uniqState(x))) FROM (SELECT intDiv(number, 3) AS k, toNullable(1) AS x FROM numbers(10)) GROUP BY k WITH TOTALS ORDER BY k; +SELECT '---'; +SELECT DISTINCT hex(toString(uniqState(x))) FROM (SELECT intDiv(number, 3) AS k, toNullable(1) AS x FROM system.numbers LIMIT 100000) GROUP BY k WITH TOTALS ORDER BY k; +SELECT '---'; +SELECT DISTINCT arrayUniq(finalizeAggregation(groupArrayState(x))) FROM (SELECT intDiv(number, 3) AS k, toNullable(1) AS x FROM 
system.numbers LIMIT 100000) GROUP BY k WITH TOTALS ORDER BY k; +SELECT '---'; +SELECT k, finalizeAggregation(uniqState(x)) FROM (SELECT intDiv(number, 3) AS k, toNullable(1) AS x FROM system.numbers LIMIT 100000) GROUP BY k WITH TOTALS ORDER BY k LIMIT 5; +SELECT '---'; +SELECT k, finalizeAggregation(uniqState(x)) FROM (WITH toNullable(number = 3 ? 3 : 1) AS d SELECT intDiv(number, 3) AS k, number % d AS x FROM system.numbers LIMIT 100000) GROUP BY k WITH TOTALS ORDER BY k LIMIT 5; +SELECT '---'; +SELECT k, finalizeAggregation(quantilesTimingState(0.5)(x)) FROM (WITH toNullable(number = 3 ? 3 : 1) AS d SELECT intDiv(number, 3) AS k, number % d AS x FROM system.numbers LIMIT 100000) GROUP BY k WITH TOTALS ORDER BY k LIMIT 5; +SELECT '---'; +SELECT k, finalizeAggregation(quantilesTimingState(0.5)(x)) FROM (SELECT intDiv(number, if(number = 9223372036854775807, -2, if(number = 3, number = if(number = 1, NULL, 3), 1)) AS d) AS k, number % d AS x FROM system.numbers LIMIT 100000) GROUP BY k WITH TOTALS ORDER BY k ASC LIMIT 5; +SELECT '---'; +SELECT DISTINCT hex(toString(uniqState(x))) FROM (SELECT materialize(1) AS k, toNullable(1) AS x FROM numbers(1)) GROUP BY k WITH ROLLUP ORDER BY k; +SELECT '---'; +SELECT DISTINCT hex(toString(uniqState(x))) FROM (SELECT materialize(1) AS k, toNullable(1) AS x FROM numbers(1)) GROUP BY k WITH CUBE ORDER BY k; +SELECT '---'; diff --git a/tests/queries/0_stateless/01381_for_each_with_states.reference b/tests/queries/0_stateless/01381_for_each_with_states.reference new file mode 100644 index 00000000000..3d1732f9d5d --- /dev/null +++ b/tests/queries/0_stateless/01381_for_each_with_states.reference @@ -0,0 +1,16 @@ +5B27015C30012CCBC234272C27015C305C30275D +02000000000000000100012CCBC234010000 +['0100012CCBC234','010000'] +[1,0] +5B27015C30012CCBC234272C27015C305C30275D + +5B27015C30012CCBC234272C27015C305C30275D +02000000000000000100012CCBC234010000 + +02000000000000000100012CCBC234010000 +['0100012CCBC234','010000'] + 
+['0100012CCBC234','010000'] +[1,0] + +[1,0] diff --git a/tests/queries/0_stateless/01381_for_each_with_states.sql b/tests/queries/0_stateless/01381_for_each_with_states.sql new file mode 100644 index 00000000000..7286ef2cb27 --- /dev/null +++ b/tests/queries/0_stateless/01381_for_each_with_states.sql @@ -0,0 +1,9 @@ +SELECT hex(toString(uniqStateForEach([1, NULL]))); +SELECT hex(toString(uniqStateForEachState([1, NULL]))); +SELECT arrayMap(x -> hex(toString(x)), finalizeAggregation(uniqStateForEachState([1, NULL]))); +SELECT arrayMap(x -> finalizeAggregation(x), finalizeAggregation(uniqStateForEachState([1, NULL]))); + +SELECT hex(toString(uniqStateForEach([1, NULL]))) WITH TOTALS; +SELECT hex(toString(uniqStateForEachState([1, NULL]))) WITH TOTALS; +SELECT arrayMap(x -> hex(toString(x)), finalizeAggregation(uniqStateForEachState([1, NULL]))) WITH TOTALS; +SELECT arrayMap(x -> finalizeAggregation(x), finalizeAggregation(uniqStateForEachState([1, NULL]))) WITH TOTALS;