From eb3422477ed2d257d5db1e06f4ab79b0dcf3929a Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 2 Aug 2020 04:29:52 +0300 Subject: [PATCH] Fix assert in *Map aggregate functions --- .../AggregateFunctionSumMap.h | 106 ++++++++++++++++-- src/Common/FieldVisitors.h | 84 -------------- .../0_stateless/01422_map_skip_null.reference | 7 ++ .../0_stateless/01422_map_skip_null.sql | 9 ++ 4 files changed, 114 insertions(+), 92 deletions(-) create mode 100644 tests/queries/0_stateless/01422_map_skip_null.reference create mode 100644 tests/queries/0_stateless/01422_map_skip_null.sql diff --git a/src/AggregateFunctions/AggregateFunctionSumMap.h b/src/AggregateFunctions/AggregateFunctionSumMap.h index ab17da1b490..c578b3f0334 100644 --- a/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -89,13 +90,15 @@ public: } else { - // No overflow, meaning we promote the types if necessary. - if (!value_type->canBePromoted()) - { - throw Exception{"Values for " + getName() + " are expected to be Numeric, Float or Decimal.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; - } + auto value_type_without_nullable = removeNullable(value_type); - result_type = value_type->promoteNumericType(); + // No overflow, meaning we promote the types if necessary. + if (!value_type_without_nullable->canBePromoted()) + throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Values for {} are expected to be Numeric, Float or Decimal, passed type {}", + getName(), value_type->getName()}; + + result_type = value_type_without_nullable->promoteNumericType(); } types.emplace_back(std::make_shared(result_type)); @@ -148,9 +151,10 @@ public: auto key = key_column.operator[](keys_vec_offset + i).get(); if (!keepKey(key)) - { continue; - } + + if (value.isNull()) + continue; typename std::decay_t::iterator it; if constexpr (IsDecimalNumber) @@ -375,6 +379,92 @@ public: bool keepKey(const T & key) const { return keys_to_keep.count(key); } }; + +/** Implements `Max` operation. + * Returns true if changed + */ +class FieldVisitorMax : public StaticVisitor +{ +private: + const Field & rhs; +public: + explicit FieldVisitorMax(const Field & rhs_) : rhs(rhs_) {} + + bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); } + bool operator() (Array &) const { throw Exception("Cannot compare Arrays", ErrorCodes::LOGICAL_ERROR); } + bool operator() (Tuple &) const { throw Exception("Cannot compare Tuples", ErrorCodes::LOGICAL_ERROR); } + bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot compare AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); } + + template + bool operator() (DecimalField & x) const + { + auto val = get>(rhs); + if (val > x) + { + x = val; + return true; + } + + return false; + } + + template + bool operator() (T & x) const + { + auto val = get(rhs); + if (val > x) + { + x = val; + return true; + } + + return false; + } +}; + +/** Implements `Min` operation. + * Returns true if changed + */ +class FieldVisitorMin : public StaticVisitor +{ +private: + const Field & rhs; +public: + explicit FieldVisitorMin(const Field & rhs_) : rhs(rhs_) {} + + bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); } + bool operator() (Array &) const { throw Exception("Cannot sum Arrays", ErrorCodes::LOGICAL_ERROR); } + bool operator() (Tuple &) const { throw Exception("Cannot sum Tuples", ErrorCodes::LOGICAL_ERROR); } + bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot sum AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); } + + template + bool operator() (DecimalField & x) const + { + auto val = get>(rhs); + if (val < x) + { + x = val; + return true; + } + + return false; + } + + template + bool operator() (T & x) const + { + auto val = get(rhs); + if (val < x) + { + x = val; + return true; + } + + return false; + } +}; + + template class AggregateFunctionMinMap final : public AggregateFunctionMapBase, FieldVisitorMin, true, tuple_argument> diff --git a/src/Common/FieldVisitors.h b/src/Common/FieldVisitors.h index 7fd4b748dbb..a749432500f 100644 --- a/src/Common/FieldVisitors.h +++ b/src/Common/FieldVisitors.h @@ -213,88 +213,4 @@ public: } }; -/** Implements `Max` operation. - * Returns true if changed - */ -class FieldVisitorMax : public StaticVisitor -{ -private: - const Field & rhs; -public: - explicit FieldVisitorMax(const Field & rhs_) : rhs(rhs_) {} - - bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); } - bool operator() (Array &) const { throw Exception("Cannot compare Arrays", ErrorCodes::LOGICAL_ERROR); } - bool operator() (Tuple &) const { throw Exception("Cannot compare Tuples", ErrorCodes::LOGICAL_ERROR); } - bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot compare AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); } - - template - bool operator() (DecimalField & x) const - { - auto val = get>(rhs); - if (val > x) - { - x = val; - return true; - } - - return false; - } - - template - bool operator() (T & x) const - { - auto val = get(rhs); - if (val > x) - { - x = val; - return true; - } - - return false; - } -}; - -/** Implements `Min` operation. - * Returns true if changed - */ -class FieldVisitorMin : public StaticVisitor -{ -private: - const Field & rhs; -public: - explicit FieldVisitorMin(const Field & rhs_) : rhs(rhs_) {} - - bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); } - bool operator() (Array &) const { throw Exception("Cannot sum Arrays", ErrorCodes::LOGICAL_ERROR); } - bool operator() (Tuple &) const { throw Exception("Cannot sum Tuples", ErrorCodes::LOGICAL_ERROR); } - bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot sum AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); } - - template - bool operator() (DecimalField & x) const - { - auto val = get>(rhs); - if (val < x) - { - x = val; - return true; - } - - return false; - } - - template - bool operator() (T & x) const - { - auto val = get(rhs); - if (val < x) - { - x = val; - return true; - } - - return false; - } -}; - } diff --git a/tests/queries/0_stateless/01422_map_skip_null.reference b/tests/queries/0_stateless/01422_map_skip_null.reference new file mode 100644 index 00000000000..7211e0ac75d --- /dev/null +++ b/tests/queries/0_stateless/01422_map_skip_null.reference @@ -0,0 +1,7 @@ +([],[]) +([],[]) +([],[]) +([2],[11]) +([2],[22]) +([2],[33]) +([2],[33]) diff --git a/tests/queries/0_stateless/01422_map_skip_null.sql b/tests/queries/0_stateless/01422_map_skip_null.sql new file mode 100644 index 00000000000..9af46758289 --- /dev/null +++ b/tests/queries/0_stateless/01422_map_skip_null.sql @@ -0,0 +1,9 @@ +select minMap(arrayJoin([([1], [null]), ([1], [null])])); +select maxMap(arrayJoin([([1], [null]), ([1], [null])])); +select sumMap(arrayJoin([([1], [null]), ([1], [null])])); -- { serverError 43 } +select sumMapWithOverflow(arrayJoin([([1], [null]), ([1], [null])])); + +select minMap(arrayJoin([([1, 2], [null, 11]), ([1, 2], [null, 22])])); +select maxMap(arrayJoin([([1, 2], [null, 11]), ([1, 2], [null, 22])])); +select sumMap(arrayJoin([([1, 2], [null, 11]), ([1, 2], [null, 22])])); +select sumMapWithOverflow(arrayJoin([([1, 2], [null, 11]), ([1, 2], [null, 22])]));