Merge pull request #58959 from Algunenano/fix_sum_map

Fix sumMapFiltered with NaN values
This commit is contained in:
Alexey Milovidov 2024-01-25 23:50:28 +01:00 committed by GitHub
commit aedffeaab0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 126 additions and 101 deletions

View File

@ -1,15 +1,16 @@
#pragma once
#include <cmath>
#include <Columns/ColumnVectorHelper.h>
#include <Columns/IColumn.h>
#include <Columns/IColumnImpl.h>
#include <Columns/ColumnVectorHelper.h>
#include <base/unaligned.h>
#include <Core/CompareHelper.h>
#include <Core/Field.h>
#include <Common/assert_cast.h>
#include <Common/TargetSpecific.h>
#include <Core/TypeId.h>
#include <base/TypeName.h>
#include <base/unaligned.h>
#include <Common/TargetSpecific.h>
#include <Common/assert_cast.h>
#include "config.h"
@ -26,91 +27,6 @@ namespace ErrorCodes
}
/** Utilities for comparing numbers.
  * Integer values use the ordinary comparison operators.
  * Floating-point values are ordered so that NaNs always land at one end
  * (without this, sorting would not work at all).
  */
template <class T, class U = T>
struct CompareHelper
{
    /// True when a < b. The NaN hint is accepted for interface uniformity and ignored here.
    static constexpr bool less(T a, U b, int /*nan_direction_hint*/)
    {
        return a < b;
    }

    /// True when a > b. The NaN hint is ignored here.
    static constexpr bool greater(T a, U b, int /*nan_direction_hint*/)
    {
        return a > b;
    }

    /// True when a == b. The NaN hint is ignored here.
    static constexpr bool equals(T a, U b, int /*nan_direction_hint*/)
    {
        return a == b;
    }

    /** Three-way comparison: the result is 1, 0 or -1 when a > b, a == b or a < b.
      * The last parameter is the NaN placement hint shared by all comparison helpers:
      *  - nan_direction_hint == -1: NaNs are considered less than all numbers;
      *  - nan_direction_hint == 1: NaNs are considered larger than all numbers.
      * In other words, nan_direction_hint == -1 means the comparison is used for a descending sort.
      * This generic implementation does not look at the hint.
      */
    static constexpr int compare(T a, U b, int /*nan_direction_hint*/)
    {
        if (a > b)
            return 1;
        return a < b ? -1 : 0;
    }
};
/** NaN-aware comparisons for a floating-point type T.
  * nan_direction_hint decides where NaNs go: a negative hint places them below all numbers,
  * a positive hint places them above all numbers. Two NaNs are treated as equal to each other.
  */
template <class T>
struct FloatCompareHelper
{
    /// Strict "a before b" under the NaN placement selected by nan_direction_hint.
    static constexpr bool less(T a, T b, int nan_direction_hint)
    {
        const bool a_is_nan = std::isnan(a);
        const bool b_is_nan = std::isnan(b);

        if (!a_is_nan && !b_is_nan)
            return a < b;
        if (a_is_nan && b_is_nan)
            return false;

        /// Exactly one operand is NaN.
        return a_is_nan ? nan_direction_hint < 0 : nan_direction_hint > 0;
    }

    /// Strict "a after b" under the NaN placement selected by nan_direction_hint.
    static constexpr bool greater(T a, T b, int nan_direction_hint)
    {
        const bool a_is_nan = std::isnan(a);
        const bool b_is_nan = std::isnan(b);

        if (!a_is_nan && !b_is_nan)
            return a > b;
        if (a_is_nan && b_is_nan)
            return false;

        /// Exactly one operand is NaN.
        return a_is_nan ? nan_direction_hint > 0 : nan_direction_hint < 0;
    }

    /// Equality defined through compare(), so two NaNs are equal.
    static constexpr bool equals(T a, T b, int nan_direction_hint)
    {
        const int order = compare(a, b, nan_direction_hint);
        return order == 0;
    }

    /** Three-way comparison.
      * For two ordinary numbers returns the sign of (a - b): -1, 0 or 1.
      * If only a is NaN returns nan_direction_hint, if only b is NaN returns -nan_direction_hint,
      * and returns 0 when both are NaN.
      */
    static constexpr int compare(T a, T b, int nan_direction_hint)
    {
        const bool a_is_nan = std::isnan(a);
        const bool b_is_nan = std::isnan(b);

        if (unlikely(a_is_nan || b_is_nan))
        {
            if (!a_is_nan)
                return -nan_direction_hint;
            return b_is_nan ? 0 : nan_direction_hint;
        }

        const T difference = a - b;
        return (T(0) < difference) - (difference < T(0));
    }
};
/// Float32 values are always compared through the NaN-aware helper, whatever the second type U is.
template <typename U>
struct CompareHelper<Float32, U> : FloatCompareHelper<Float32>
{
};

/// Float64 values are always compared through the NaN-aware helper, whatever the second type U is.
template <typename U>
struct CompareHelper<Float64, U> : FloatCompareHelper<Float64>
{
};
/** A template for columns that store their values in a simple array.
*/
template <typename T>

93
src/Core/CompareHelper.h Normal file
View File

@ -0,0 +1,93 @@
#pragma once
#include <base/defines.h>
#include <base/types.h>
#include <cmath>
namespace DB
{
/** Helpers for comparing numeric values.
  * Integers are compared with the ordinary operators.
  * Floating-point values are handled by FloatCompareHelper so that NaNs always end up
  * at one end of the order (if this is not done, sorting does not work at all).
  */
template <class T, class U = T>
struct CompareHelper
{
    /// Plain `a < b`; the unnamed NaN hint is not used by the generic implementation.
    static constexpr bool less(T a, U b, int /*nan_direction_hint*/)
    {
        return a < b;
    }

    /// Plain `a > b`; the unnamed NaN hint is not used by the generic implementation.
    static constexpr bool greater(T a, U b, int /*nan_direction_hint*/)
    {
        return a > b;
    }

    /// Plain `a == b`; the unnamed NaN hint is not used by the generic implementation.
    static constexpr bool equals(T a, U b, int /*nan_direction_hint*/)
    {
        return a == b;
    }

    /** Three-way comparison of two numbers.
      * Returns -1 if a < b, 0 if a == b and 1 if a > b.
      * The third parameter belongs to the common helper interface and describes NaN placement:
      *  - nan_direction_hint == -1: NaNs are considered less than all numbers;
      *  - nan_direction_hint == 1: NaNs are considered larger than all numbers.
      * Essentially, nan_direction_hint == -1 says that the comparison is for sorting in descending order.
      * The generic implementation ignores it.
      */
    static constexpr int compare(T a, U b, int /*nan_direction_hint*/)
    {
        if (a > b)
            return 1;
        if (a < b)
            return -1;
        return 0;
    }
};
/** Comparison helper for floating-point type T that gives NaNs a definite position.
  * A negative nan_direction_hint sorts NaNs before every number, a positive one after every number.
  * Two NaNs compare as equal.
  */
template <class T>
struct FloatCompareHelper
{
    /// Returns true if a is ordered strictly before b.
    static constexpr bool less(T a, T b, int nan_direction_hint)
    {
        const bool lhs_nan = std::isnan(a);
        const bool rhs_nan = std::isnan(b);

        if (lhs_nan)
        {
            /// NaN vs NaN is never "less"; a lone NaN on the left is less only when NaNs sort first.
            return !rhs_nan && nan_direction_hint < 0;
        }
        if (rhs_nan)
            return nan_direction_hint > 0;

        return a < b;
    }

    /// Returns true if a is ordered strictly after b.
    static constexpr bool greater(T a, T b, int nan_direction_hint)
    {
        const bool lhs_nan = std::isnan(a);
        const bool rhs_nan = std::isnan(b);

        if (lhs_nan)
        {
            /// NaN vs NaN is never "greater"; a lone NaN on the left is greater only when NaNs sort last.
            return !rhs_nan && nan_direction_hint > 0;
        }
        if (rhs_nan)
            return nan_direction_hint < 0;

        return a > b;
    }

    /// Returns true if compare() reports the operands as equal (this includes NaN vs NaN).
    static constexpr bool equals(T a, T b, int nan_direction_hint)
    {
        return 0 == compare(a, b, nan_direction_hint);
    }

    /** Three-way comparison.
      * Without NaNs the result is the sign of (a - b).
      * With NaNs: 0 if both are NaN, nan_direction_hint if only a is NaN,
      * -nan_direction_hint if only b is NaN.
      */
    static constexpr int compare(T a, T b, int nan_direction_hint)
    {
        const bool lhs_nan = std::isnan(a);
        const bool rhs_nan = std::isnan(b);

        if (unlikely(lhs_nan || rhs_nan))
        {
            if (lhs_nan && rhs_nan)
                return 0;
            if (lhs_nan)
                return nan_direction_hint;
            return -nan_direction_hint;
        }

        const T delta = a - b;
        return (T(0) < delta) - (delta < T(0));
    }
};
/// Whenever the first type is Float32, comparisons go through the NaN-aware FloatCompareHelper.
template <typename U> struct CompareHelper<Float32, U> : FloatCompareHelper<Float32> {};

/// Whenever the first type is Float64, comparisons go through the NaN-aware FloatCompareHelper.
template <typename U> struct CompareHelper<Float64, U> : FloatCompareHelper<Float64> {};
}

View File

@ -7,14 +7,15 @@
#include <type_traits>
#include <functional>
#include <Common/Exception.h>
#include <Common/AllocatorWithMemoryTracking.h>
#include <Core/Types.h>
#include <Core/Defines.h>
#include <Core/CompareHelper.h>
#include <Core/DecimalFunctions.h>
#include <Core/Defines.h>
#include <Core/Types.h>
#include <Core/UUID.h>
#include <base/IPv4andIPv6.h>
#include <base/DayNum.h>
#include <base/IPv4andIPv6.h>
#include <Common/AllocatorWithMemoryTracking.h>
#include <Common/Exception.h>
namespace DB
{
@ -305,6 +306,7 @@ static constexpr auto DBMS_MIN_FIELD_SIZE = 32;
*/
class Field
{
static constexpr int nan_direction_hint = 1; // When comparing Floats NaN are considered to be larger than all numbers
public:
struct Types
{
@ -508,7 +510,8 @@ public:
case Types::UUID: return get<UUID>() < rhs.get<UUID>();
case Types::IPv4: return get<IPv4>() < rhs.get<IPv4>();
case Types::IPv6: return get<IPv6>() < rhs.get<IPv6>();
case Types::Float64: return get<Float64>() < rhs.get<Float64>();
case Types::Float64:
return FloatCompareHelper<Float64>::less(get<Float64>(), rhs.get<Float64>(), nan_direction_hint);
case Types::String: return get<String>() < rhs.get<String>();
case Types::Array: return get<Array>() < rhs.get<Array>();
case Types::Tuple: return get<Tuple>() < rhs.get<Tuple>();
@ -550,7 +553,13 @@ public:
case Types::UUID: return get<UUID>().toUnderType() <= rhs.get<UUID>().toUnderType();
case Types::IPv4: return get<IPv4>() <= rhs.get<IPv4>();
case Types::IPv6: return get<IPv6>() <= rhs.get<IPv6>();
case Types::Float64: return get<Float64>() <= rhs.get<Float64>();
case Types::Float64:
{
    /// NaN-aware `<=`: true when this value is ordered before rhs or compares equal to it.
    /// The right-hand operand must be read from rhs; reading it from this object would compare
    /// the value with itself and make the result unconditionally true.
    const Float64 f1 = get<Float64>();
    const Float64 f2 = rhs.get<Float64>();
    return FloatCompareHelper<Float64>::less(f1, f2, nan_direction_hint)
        || FloatCompareHelper<Float64>::equals(f1, f2, nan_direction_hint);
}
case Types::String: return get<String>() <= rhs.get<String>();
case Types::Array: return get<Array>() <= rhs.get<Array>();
case Types::Tuple: return get<Tuple>() <= rhs.get<Tuple>();
@ -586,10 +595,7 @@ public:
case Types::UInt64: return get<UInt64>() == rhs.get<UInt64>();
case Types::Int64: return get<Int64>() == rhs.get<Int64>();
case Types::Float64:
{
// Compare as UInt64 so that NaNs compare as equal.
return std::bit_cast<UInt64>(get<Float64>()) == std::bit_cast<UInt64>(rhs.get<Float64>());
}
return FloatCompareHelper<Float64>::equals(get<Float64>(), rhs.get<Float64>(), nan_direction_hint);
case Types::UUID: return get<UUID>() == rhs.get<UUID>();
case Types::IPv4: return get<IPv4>() == rhs.get<IPv4>();
case Types::IPv6: return get<IPv6>() == rhs.get<IPv6>();

View File

@ -335,7 +335,11 @@ public:
&& column_array->getOffsets() != typeid_cast<const ColumnArray::ColumnOffsets &>(*offsets_column).getData())
throw Exception(
ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH,
"Arrays passed to {} must have equal size", getName());
"Arrays passed to {} must have equal size. Argument {} has size {}, but expected {}",
getName(),
i,
column_array->getOffsets().size(),
typeid_cast<const ColumnArray::ColumnOffsets &>(*offsets_column).getData().size());
}
const auto * column_tuple = checkAndGetColumn<ColumnTuple>(&column_array->getData());

View File

@ -0,0 +1,2 @@
([6.7],[3])
([1,4,5,6.7,nan],[2.3,5,1,3,inf])

View File

@ -0,0 +1,4 @@
-- sumMapFiltered restricted to the single key 6.7; the input rows also contain a NaN key.
SELECT sumMapFiltered([6.7])([x], [y])
FROM values('x Float64, y Float64', (0, 1), (1, 2.3), (nan, inf), (6.7, 3), (4, 4), (5, 1));
-- sumMap over all keys; the input contains a repeated key (4) and a NaN key whose value is inf.
SELECT sumMap([x],[y]) FROM values('x Float64, y Float64', (4, 1), (1, 2.3), (nan,inf), (6.7,3), (4,4), (5, 1));