Fix assert in *Map aggregate functions

This commit is contained in:
Alexey Milovidov 2020-08-02 04:29:52 +03:00
parent 0adf8c723e
commit eb3422477e
4 changed files with 114 additions and 92 deletions

View File

@ -5,6 +5,7 @@
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeNullable.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnTuple.h>
@ -89,13 +90,15 @@ public:
}
else
{
// No overflow, meaning we promote the types if necessary.
if (!value_type->canBePromoted())
{
throw Exception{"Values for " + getName() + " are expected to be Numeric, Float or Decimal.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
auto value_type_without_nullable = removeNullable(value_type);
result_type = value_type->promoteNumericType();
// No overflow, meaning we promote the types if necessary.
if (!value_type_without_nullable->canBePromoted())
throw Exception{ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Values for {} are expected to be Numeric, Float or Decimal, passed type {}",
getName(), value_type->getName()};
result_type = value_type_without_nullable->promoteNumericType();
}
types.emplace_back(std::make_shared<DataTypeArray>(result_type));
@ -148,9 +151,10 @@ public:
auto key = key_column.operator[](keys_vec_offset + i).get<T>();
if (!keepKey(key))
{
continue;
}
if (value.isNull())
continue;
typename std::decay_t<decltype(merged_maps)>::iterator it;
if constexpr (IsDecimalNumber<T>)
@ -375,6 +379,92 @@ public:
bool keepKey(const T & key) const { return keys_to_keep.count(key); }
};
/** Implements `Max` operation.
* Returns true if changed
*/
class FieldVisitorMax : public StaticVisitor<bool>
{
private:
const Field & rhs;
public:
explicit FieldVisitorMax(const Field & rhs_) : rhs(rhs_) {}
bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Array &) const { throw Exception("Cannot compare Arrays", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Tuple &) const { throw Exception("Cannot compare Tuples", ErrorCodes::LOGICAL_ERROR); }
bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot compare AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); }
template <typename T>
bool operator() (DecimalField<T> & x) const
{
auto val = get<DecimalField<T>>(rhs);
if (val > x)
{
x = val;
return true;
}
return false;
}
template <typename T>
bool operator() (T & x) const
{
auto val = get<T>(rhs);
if (val > x)
{
x = val;
return true;
}
return false;
}
};
/** Implements `Min` operation.
* Returns true if changed
*/
class FieldVisitorMin : public StaticVisitor<bool>
{
private:
const Field & rhs;
public:
explicit FieldVisitorMin(const Field & rhs_) : rhs(rhs_) {}
bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Array &) const { throw Exception("Cannot sum Arrays", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Tuple &) const { throw Exception("Cannot sum Tuples", ErrorCodes::LOGICAL_ERROR); }
bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot sum AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); }
template <typename T>
bool operator() (DecimalField<T> & x) const
{
auto val = get<DecimalField<T>>(rhs);
if (val < x)
{
x = val;
return true;
}
return false;
}
template <typename T>
bool operator() (T & x) const
{
auto val = get<T>(rhs);
if (val < x)
{
x = val;
return true;
}
return false;
}
};
template <typename T, bool tuple_argument>
class AggregateFunctionMinMap final :
public AggregateFunctionMapBase<T, AggregateFunctionMinMap<T, tuple_argument>, FieldVisitorMin, true, tuple_argument>

View File

@ -213,88 +213,4 @@ public:
}
};
/** Implements `Max` operation.
* Returns true if changed
*/
class FieldVisitorMax : public StaticVisitor<bool>
{
private:
const Field & rhs;
public:
explicit FieldVisitorMax(const Field & rhs_) : rhs(rhs_) {}
bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Array &) const { throw Exception("Cannot compare Arrays", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Tuple &) const { throw Exception("Cannot compare Tuples", ErrorCodes::LOGICAL_ERROR); }
bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot compare AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); }
template <typename T>
bool operator() (DecimalField<T> & x) const
{
auto val = get<DecimalField<T>>(rhs);
if (val > x)
{
x = val;
return true;
}
return false;
}
template <typename T>
bool operator() (T & x) const
{
auto val = get<T>(rhs);
if (val > x)
{
x = val;
return true;
}
return false;
}
};
/** Implements `Min` operation.
* Returns true if changed
*/
class FieldVisitorMin : public StaticVisitor<bool>
{
private:
const Field & rhs;
public:
explicit FieldVisitorMin(const Field & rhs_) : rhs(rhs_) {}
bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Array &) const { throw Exception("Cannot sum Arrays", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Tuple &) const { throw Exception("Cannot sum Tuples", ErrorCodes::LOGICAL_ERROR); }
bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot sum AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); }
template <typename T>
bool operator() (DecimalField<T> & x) const
{
auto val = get<DecimalField<T>>(rhs);
if (val < x)
{
x = val;
return true;
}
return false;
}
template <typename T>
bool operator() (T & x) const
{
auto val = get<T>(rhs);
if (val < x)
{
x = val;
return true;
}
return false;
}
};
}

View File

@ -0,0 +1,7 @@
([],[])
([],[])
([],[])
([2],[11])
([2],[22])
([2],[33])
([2],[33])

View File

@ -0,0 +1,9 @@
select minMap(arrayJoin([([1], [null]), ([1], [null])]));
select maxMap(arrayJoin([([1], [null]), ([1], [null])]));
select sumMap(arrayJoin([([1], [null]), ([1], [null])])); -- { serverError 43 }
select sumMapWithOverflow(arrayJoin([([1], [null]), ([1], [null])]));
select minMap(arrayJoin([([1, 2], [null, 11]), ([1, 2], [null, 22])]));
select maxMap(arrayJoin([([1, 2], [null, 11]), ([1, 2], [null, 22])]));
select sumMap(arrayJoin([([1, 2], [null, 11]), ([1, 2], [null, 22])]));
select sumMapWithOverflow(arrayJoin([([1, 2], [null, 11]), ([1, 2], [null, 22])]));