Implement findExtremeMinIndex / findExtremeMaxIndex

This commit is contained in:
Raúl Marín 2024-01-17 15:51:21 +01:00
parent 6daaf3dc72
commit 849ac1fe99
7 changed files with 112 additions and 29 deletions

View File

@ -421,15 +421,7 @@ std::optional<size_t> SingleValueDataFixed<T>::getSmallestIndex(const IColumn &
const auto & vec = assert_cast<const ColVecType &>(column);
if constexpr (has_find_extreme_implementation<T>)
{
std::optional<T> opt = findExtremeMin(vec.getData().data(), row_begin, row_end);
if (!opt || (has() && value <= opt))
return std::nullopt;
/// TODO: Implement findExtremeMinIndex to do the lookup properly (with SIMD and batching)
for (size_t i = row_begin; i < row_end; i++)
if (vec.getData()[i] == *opt)
return i;
return row_end;
return findExtremeMinIndex(vec.getData().data(), row_begin, row_end);
}
else
{
@ -450,15 +442,7 @@ std::optional<size_t> SingleValueDataFixed<T>::getGreatestIndex(const IColumn &
const auto & vec = assert_cast<const ColVecType &>(column);
if constexpr (has_find_extreme_implementation<T>)
{
std::optional<T> opt = findExtremeMax(vec.getData().data(), row_begin, row_end);
if (!opt || (has() && value >= opt))
return std::nullopt;
/// TODO: Implement findExtremeMaxIndex to do the lookup properly (with SIMD and batching)
for (size_t i = row_begin; i < row_end; i++)
if (vec.getData()[i] == *opt)
return i;
return row_end;
return findExtremeMaxIndex(vec.getData().data(), row_begin, row_end);
}
else
{

View File

@ -57,8 +57,7 @@ struct SingleValueDataBase
virtual void setGreatestNotNullIf(const IColumn &, const UInt8 * __restrict, const UInt8 * __restrict, size_t, size_t, Arena *);
/// Given a column returns the index of the smallest or greatest value in it
/// Doesn't return anything if the column is empty. In some cases (SingleValueDataFixed<T>) it will also return
/// empty if the stored value is already the smallest/greatest
/// Doesn't return anything if the column is empty
/// These are used to implement argMin / argMax
virtual std::optional<size_t> getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end);
virtual std::optional<size_t> getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end);

View File

@ -19,6 +19,7 @@
#include <Common/TargetSpecific.h>
#include <Common/WeakHash.h>
#include <Common/assert_cast.h>
#include <Common/findExtreme.h>
#include <Common/iota.h>
#include <bit>
@ -247,6 +248,26 @@ void ColumnVector<T>::getPermutation(IColumn::PermutationSortDirection direction
iota(res.data(), data_size, IColumn::Permutation::value_type(0));
if constexpr (has_find_extreme_implementation<T> && !std::is_floating_point_v<T>)
{
/// Disabled for:
/// * floating point: We don't deal with nan_direction_hint
/// * stability::Stable: We might return any value, not the first
if ((limit == 1) && (stability == IColumn::PermutationSortStability::Unstable))
{
std::optional<size_t> index;
if (direction == IColumn::PermutationSortDirection::Ascending)
index = findExtremeMinIndex(data.data(), 0, data.size());
else
index = findExtremeMaxIndex(data.data(), 0, data.size());
if (index)
{
res.data()[0] = *index;
return;
}
}
}
if constexpr (is_arithmetic_v<T> && !is_big_int_v<T>)
{
if (!limit)

View File

@ -133,14 +133,54 @@ std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __rest
return findExtreme<T, MaxComparator<T>, false, false>(ptr, condition_map, start, end);
}
/// Returns the index of an element of ptr[start, end) that is equal to the minimum reported by findExtremeMin,
/// or std::nullopt when findExtremeMin reports no value or no element compares equal to it.
/// When the minimum appears several times, `start` is returned if ptr[start] holds it; otherwise the position
/// of its last occurrence is returned.
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end)
{
    /// This is implemented based on findNumericExtreme and not the other way around (or independently) because getting
    /// the MIN or MAX value of an array is possible with SIMD, but getting the index isn't.
    /// So what we do is use SIMD to find the lowest value and then iterate again over the array to find its position
    std::optional<T> opt = findExtremeMin(ptr, start, end);
    if (!opt)
        return std::nullopt;

    /// Some minimal heuristics for the case the input is sorted
    if (*opt == ptr[start])
        return {start};

    /// ptr[start] was already checked above, so scan the remaining positions (start, end) backwards.
    /// The bound must be `i > start`: stopping at `start + 1` would skip that position and miss a minimum located there.
    for (size_t i = end - 1; i > start; i--)
        if (ptr[i] == *opt)
            return {i};

    return std::nullopt;
}
/// Returns the index of an element of ptr[start, end) that is equal to the maximum reported by findExtremeMax,
/// or std::nullopt when findExtremeMax reports no value or no element compares equal to it.
/// When the maximum appears several times, `start` is returned if ptr[start] holds it; otherwise the position
/// of its last occurrence is returned.
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end)
{
    /// First obtain the greatest value, then look for its position in a second pass over the array
    std::optional<T> opt = findExtremeMax(ptr, start, end);
    if (!opt)
        return std::nullopt;

    /// Some minimal heuristics for the case the input is sorted
    if (*opt == ptr[start])
        return {start};

    /// ptr[start] was already checked above, so scan the remaining positions (start, end) backwards.
    /// The bound must be `i > start`: stopping at `start + 1` would skip that position and miss a maximum located there.
    for (size_t i = end - 1; i > start; i--)
        if (ptr[i] == *opt)
            return {i};

    return std::nullopt;
}
#define INSTANTIATION(T) \
template std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end); \
template std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMinNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMinIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end); \
template std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template std::optional<T> findExtremeMaxNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<T> findExtremeMaxIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
template std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end); \
template std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end);
FOR_BASIC_NUMERIC_TYPES(INSTANTIATION)
#undef INSTANTIATION

View File

@ -31,15 +31,27 @@ std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * _
template <has_find_extreme_implementation T>
std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end);
template <has_find_extreme_implementation T>
std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end);
#define EXTERN_INSTANTIATION(T) \
extern template std::optional<T> findExtremeMin(const T * __restrict ptr, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMinIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMax(const T * __restrict ptr, size_t start, size_t end); \
extern template std::optional<T> findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end);
extern template std::optional<T> findExtremeMaxNotNull( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<T> findExtremeMaxIf( \
const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \
extern template std::optional<size_t> findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end); \
extern template std::optional<size_t> findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end);
FOR_BASIC_NUMERIC_TYPES(EXTERN_INSTANTIATION)
FOR_BASIC_NUMERIC_TYPES(EXTERN_INSTANTIATION)
#undef EXTERN_INSTANTIATION
}

View File

@ -0,0 +1,24 @@
<test>
<!-- Performance queries exercising argMin / argMinIf / argMax over hits_100m_single. -->
<!-- Most queries group by intHash32(UserID) % {group_scale}; group_scale is substituted with the single value below. -->
<substitutions>
<substitution>
<name>group_scale</name>
<values>
<value>1000000</value>
</values>
</substitution>
</substitutions>
<query>select argMin(Title, EventTime) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMinIf(Title, EventTime, Title != '') from hits_100m_single group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMinIf(Title::Nullable(String), EventTime::Nullable(DateTime), Title::Nullable(String) != '') from hits_100m_single group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMin(RegionID, EventTime) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<query>select argMin((Title, RegionID), EventTime) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<!-- NOTE(review): the next query is textually identical to the second query above; confirm whether the repetition is intended. -->
<query>select argMinIf(Title, EventTime, Title != '') from hits_100m_single group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<!-- Ungrouped argMax over the whole table, with a plain and a Nullable key. -->
<query>select argMax(WatchID, Age) from hits_100m_single FORMAT Null</query>
<query>select argMax(WatchID, Age::Nullable(UInt8)) from hits_100m_single FORMAT Null</query>
<query>select argMax(WatchID, (EventDate, EventTime)) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null</query>
<!-- NOTE(review): unlike every other query here, the next one has no FORMAT Null; confirm this is intentional. -->
<query>select argMax(MobilePhone, MobilePhoneModel) from hits_100m_single</query>
</test>

View File

@ -1,4 +1,5 @@
<test>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 1 FORMAT Null</query>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 10 FORMAT Null</query>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 100 FORMAT Null</query>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 1500 FORMAT Null</query>
@ -7,6 +8,7 @@
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 10000 FORMAT Null</query>
<query>SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 65535 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(500000000) ORDER BY n LIMIT 1 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(500000000) ORDER BY n LIMIT 10 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n LIMIT 100 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n LIMIT 1500 FORMAT Null</query>
@ -15,6 +17,7 @@
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n LIMIT 10000 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(100000000) ORDER BY n LIMIT 65535 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 1 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 10 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 100 FORMAT Null</query>
<query>SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 1500 FORMAT Null</query>