diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp b/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp index bfa2bcb032c..abcc104d520 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.cpp @@ -5,12 +5,20 @@ namespace DB AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested) { - return std::make_shared(nested); + const DataTypePtr & nested_return_type = nested->getReturnType(); + if (nested_return_type && !nested_return_type->canBeInsideNullable()) + return std::make_shared>(nested); + else + return std::make_shared>(nested); } AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested) { - return std::make_shared(nested); + const DataTypePtr & nested_return_type = nested->getReturnType(); + if (nested_return_type && !nested_return_type->canBeInsideNullable()) + return std::make_shared>(nested); + else + return std::make_shared>(nested); } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionNull.h b/dbms/src/AggregateFunctions/AggregateFunctionNull.h index 4d0c5e5f746..763097fb7c8 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionNull.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionNull.h @@ -24,6 +24,10 @@ namespace ErrorCodes /// at least one nullable argument. It implements the logic according to which any /// row that contains at least one NULL is skipped. +/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter. +/// true - return NULL; false - return value from empty aggregation state of nested function. + +template class AggregateFunctionNullBase : public IAggregateFunction { protected: @@ -36,27 +40,29 @@ protected: static AggregateDataPtr nestedPlace(AggregateDataPtr place) noexcept { - return place + 1; + return place + (result_is_nullable ? 1 : 0); } static ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr place) noexcept { - return place + 1; + return place + (result_is_nullable ? 1 : 0); } static void initFlag(AggregateDataPtr place) noexcept { - place[0] = 0; + if (result_is_nullable) + place[0] = 0; } static void setFlag(AggregateDataPtr place) noexcept { - place[0] = 1; + if (result_is_nullable) + place[0] = 1; } static bool getFlag(ConstAggregateDataPtr place) noexcept { - return place[0]; + return result_is_nullable ? place[0] : 1; } public: @@ -78,7 +84,9 @@ public: DataTypePtr getReturnType() const override { - return std::make_shared(nested_function->getReturnType()); + return result_is_nullable + ? std::make_shared(nested_function->getReturnType()) + : nested_function->getReturnType(); } void create(AggregateDataPtr place) const override @@ -109,7 +117,7 @@ public: void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override { - if (getFlag(rhs)) + if (result_is_nullable && getFlag(rhs)) setFlag(place); nested_function->merge(nestedPlace(place), nestedPlace(rhs), arena); @@ -118,15 +126,17 @@ public: void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { bool flag = getFlag(place); - writeBinary(flag, buf); + if (result_is_nullable) + writeBinary(flag, buf); if (flag) nested_function->serialize(nestedPlace(place), buf); } void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override { - bool flag; - readBinary(flag, buf); + bool flag = 1; + if (result_is_nullable) + readBinary(flag, buf); if (flag) { setFlag(place); @@ -136,15 +146,22 @@ public: void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { - ColumnNullable & to_concrete = static_cast(to); - if (getFlag(place)) + if (result_is_nullable) { - nested_function->insertResultInto(nestedPlace(place), *to_concrete.getNestedColumn()); - to_concrete.getNullMap().push_back(0); + ColumnNullable & to_concrete = static_cast(to); + if (getFlag(place)) + { + nested_function->insertResultInto(nestedPlace(place), *to_concrete.getNestedColumn()); + to_concrete.getNullMap().push_back(0); + } + else + { + to_concrete.insertDefault(); + } } else { - to_concrete.insertDefault(); + nested_function->insertResultInto(nestedPlace(place), to); } } @@ -165,10 +182,11 @@ public: /** There are two cases: for single argument and variadic. * Code for single argument is much more efficient. */ -class AggregateFunctionNullUnary final : public AggregateFunctionNullBase +template +class AggregateFunctionNullUnary final : public AggregateFunctionNullBase { public: - using AggregateFunctionNullBase::AggregateFunctionNullBase; + using AggregateFunctionNullBase::AggregateFunctionNullBase; void setArguments(const DataTypes & arguments) override { @@ -178,7 +196,7 @@ public: if (!arguments.front()->isNullable()) throw Exception("Logical error: not nullable data type is passed to AggregateFunctionNullUnary", ErrorCodes::LOGICAL_ERROR); - nested_function->setArguments({static_cast(*arguments.front()).getNestedType()}); + this->nested_function->setArguments({static_cast(*arguments.front()).getNestedType()}); } void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override @@ -186,9 +204,9 @@ public: const ColumnNullable * column = static_cast(columns[0]); if (!column->isNullAt(row_num)) { - setFlag(place); + this->setFlag(place); const IColumn * nested_column = column->getNestedColumn().get(); - nested_function->add(nestedPlace(place), &nested_column, row_num, arena); + this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena); } } @@ -198,17 +216,18 @@ public: return static_cast(*that).add(place, columns, row_num, arena); } - AddFunc getAddressOfAddFunction() const override + IAggregateFunction::AddFunc getAddressOfAddFunction() const override { return &addFree; } }; -class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase +template +class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase { public: - using AggregateFunctionNullBase::AggregateFunctionNullBase; + using AggregateFunctionNullBase::AggregateFunctionNullBase; void setArguments(const DataTypes & arguments) override { @@ -238,7 +257,7 @@ public: nested_args[i] = arguments[i]; } - nested_function->setArguments(nested_args); + this->nested_function->setArguments(nested_args); } void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override @@ -263,13 +282,13 @@ public: nested_columns[i] = columns[i]; } - setFlag(place); - nested_function->add(nestedPlace(place), nested_columns, row_num, arena); + this->setFlag(place); + this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); } bool allocatesMemoryInArena() const override { - return nested_function->allocatesMemoryInArena(); + return this->nested_function->allocatesMemoryInArena(); } static void addFree(const IAggregateFunction * that, AggregateDataPtr place, @@ -278,7 +297,7 @@ public: return static_cast(*that).add(place, columns, row_num, arena); } - AddFunc getAddressOfAddFunction() const override + IAggregateFunction::AddFunc getAddressOfAddFunction() const override { return &addFree; } diff --git a/dbms/tests/queries/0_stateless/00525_aggregate_functions_of_nullable_that_return_non_nullable.reference b/dbms/tests/queries/0_stateless/00525_aggregate_functions_of_nullable_that_return_non_nullable.reference new file mode 100644 index 00000000000..0d76495acb3 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00525_aggregate_functions_of_nullable_that_return_non_nullable.reference @@ -0,0 +1,2 @@ +1 [1,2] Array(UInt8) 1.5 Nullable(Float64) +2 [] Array(UInt8) \N Nullable(Float64) diff --git a/dbms/tests/queries/0_stateless/00525_aggregate_functions_of_nullable_that_return_non_nullable.sql b/dbms/tests/queries/0_stateless/00525_aggregate_functions_of_nullable_that_return_non_nullable.sql new file mode 100644 index 00000000000..598de2e74ea --- /dev/null +++ b/dbms/tests/queries/0_stateless/00525_aggregate_functions_of_nullable_that_return_non_nullable.sql @@ -0,0 +1 @@ +SELECT k, groupArray(x) AS res1, toTypeName(res1), avg(x) AS res2, toTypeName(res2) FROM (SELECT 1 AS k, arrayJoin([1, NULL, 2]) AS x UNION ALL SELECT 2 AS k, CAST(arrayJoin([NULL, NULL]) AS Nullable(UInt8)) AS x) GROUP BY k ORDER BY k;