Code cleanups after #4439

This commit is contained in:
Alexey Milovidov 2019-04-25 04:16:26 +03:00
parent 445f51c01e
commit 661c840fbe
3 changed files with 143 additions and 204 deletions

View File

@ -68,12 +68,33 @@ struct ColumnVector<T>::greater
bool operator()(size_t lhs, size_t rhs) const { return CompareHelper<T>::greater(parent.data[lhs], parent.data[rhs], nan_direction_hint); }
};
namespace
{
template <typename T>
struct ValueWithIndex
{
T value;
UInt32 index;
};
template <typename T>
struct RadixSortTraits : RadixSortNumTraits<T>
{
using Element = ValueWithIndex<T>;
static T & extractKey(Element & elem) { return elem.value; }
};
}
template <typename T>
void ColumnVector<T>::getPermutation(bool reverse, size_t limit, int nan_direction_hint, IColumn::Permutation & res) const
{
size_t s = data.size();
res.resize(s);
if (s == 0)
return;
if (limit >= s)
limit = 0;
@ -89,23 +110,68 @@ void ColumnVector<T>::getPermutation(bool reverse, size_t limit, int nan_directi
}
else
{
if constexpr ((std::is_signed_v<T> || std::is_unsigned_v<T>) && !std::is_same_v<T, UInt128>)
/// A case for radix sort
if constexpr (std::is_arithmetic_v<T> && !std::is_same_v<T, UInt128>)
{
PaddedPODArray<std::pair<T, size_t>> pairs(s);
for (size_t i = 0; i < s; ++i)
/// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters.
if (s >= 256 && s <= std::numeric_limits<UInt32>::max())
{
PaddedPODArray<ValueWithIndex<T>> pairs(s);
for (UInt32 i = 0; i < s; ++i)
pairs[i] = {data[i], i};
radixSort(pairs.data(), s, nan_direction_hint);
RadixSort<RadixSortTraits<T>>::execute(pairs.data(), s);
/// Radix sort treats all NaNs to be greater than all numbers.
/// If the user needs the opposite, we must move them accordingly.
size_t nans_to_move = 0;
if (std::is_floating_point_v<T> && nan_direction_hint < 0)
{
for (ssize_t i = s - 1; i >= 0; --i)
{
if (isNaN(pairs[i].value))
++nans_to_move;
else
break;
}
}
if (reverse)
for (size_t i = 0; i < s; ++i)
res[s - 1 - i] = pairs[i].second;
else
for (size_t i = 0; i < s; ++i)
res[i] = pairs[i].second;
{
if (nans_to_move)
{
for (size_t i = 0; i < s - nans_to_move; ++i)
res[i] = pairs[s - nans_to_move - 1 - i].index;
for (size_t i = s - nans_to_move; i < s; ++i)
res[i] = pairs[s - 1 - (i - (s - nans_to_move))].index;
}
else
{
for (size_t i = 0; i < s; ++i)
res[s - 1 - i] = pairs[i].index;
}
}
else
{
if (nans_to_move)
{
for (size_t i = 0; i < nans_to_move; ++i)
res[i] = pairs[i + s - nans_to_move].index;
for (size_t i = nans_to_move; i < s; ++i)
res[i] = pairs[i - nans_to_move].index;
}
else
{
for (size_t i = 0; i < s; ++i)
res[i] = pairs[i].index;
}
}
return;
}
}
/// Default sorting algorithm.
for (size_t i = 0; i < s; ++i)
res[i] = i;
@ -115,7 +181,7 @@ void ColumnVector<T>::getPermutation(bool reverse, size_t limit, int nan_directi
pdqsort(res.begin(), res.end(), less(*this, nan_direction_hint));
}
}
}
template <typename T>
const char * ColumnVector<T>::getFamilyName() const

View File

@ -10,7 +10,6 @@
#include <cstdlib>
#include <cstdint>
#include <type_traits>
#include <vector>
#include <ext/bit_cast.h>
#include <Core/Types.h>
@ -68,15 +67,15 @@ struct RadixSortFloatTransform
};
template <typename _Element, typename _Key = _Element>
template <typename TElement>
struct RadixSortFloatTraits
{
using Element = _Element; /// The type of the element. It can be a structure with a key and some other payload. Or just a key.
using Key = _Key; /// The key to sort.
using Element = TElement; /// The type of the element. It can be a structure with a key and some other payload. Or just a key.
using Key = Element; /// The key to sort by.
using CountType = uint32_t; /// Type for calculating histograms. In the case of a known small number of elements, it can be less than size_t.
/// The type to which the key is transformed to do bit operations. This UInt is the same size as the key.
using KeyBits = std::conditional_t<sizeof(_Key) == 8, uint64_t, uint32_t>;
using KeyBits = std::conditional_t<sizeof(Key) == 8, uint64_t, uint32_t>;
static constexpr size_t PART_SIZE_BITS = 8; /// With what pieces of the key, in bits, to do one pass - reshuffle of the array.
@ -89,30 +88,7 @@ struct RadixSortFloatTraits
using Allocator = RadixSortMallocAllocator;
/// The function to get the key from an array element.
static Key & extractKey(Element & elem)
{
if constexpr (std::is_same_v<Element, Key>)
return elem;
else
return *reinterpret_cast<Key *>(&elem);
}
};
template <typename Float>
struct RadixSortPairFloatKeyTraits
{
using Element = std::pair<Float, size_t>;
using Key = Float;
using CountType = uint32_t;
using KeyBits = std::conditional_t<sizeof(Float) == 8, uint64_t, uint32_t>;
static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortFloatTransform<KeyBits>;
using Allocator = RadixSortMallocAllocator;
/// The function to get the key from an array element.
static Key & extractKey(Element & elem) { return elem.first; }
static Key & extractKey(Element & elem) { return elem; }
};
template <typename KeyBits>
@ -135,92 +111,49 @@ struct RadixSortSignedTransform
};
template <typename _Element, typename _Key = _Element>
template <typename TElement>
struct RadixSortUIntTraits
{
using Element = _Element;
using Key = _Key;
using Element = TElement;
using Key = Element;
using CountType = uint32_t;
using KeyBits = _Key;
using KeyBits = Key;
static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortIdentityTransform<KeyBits>;
using Allocator = RadixSortMallocAllocator;
/// The function to get the key from an array element.
static Key & extractKey(Element & elem)
{
if constexpr (std::is_same_v<Element, Key>)
return elem;
else
return *reinterpret_cast<Key *>(&elem);
}
static Key & extractKey(Element & elem) { return elem; }
};
template <typename UInt>
struct RadixSortPairUIntKeyTraits
{
using Element = std::pair<UInt, size_t>;
using Key = UInt;
using CountType = uint32_t;
using KeyBits = UInt;
static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortIdentityTransform<KeyBits>;
using Allocator = RadixSortMallocAllocator;
/// The function to get the key from an array element.
static Key & extractKey(Element & elem) { return elem.first; }
};
template <typename _Element, typename _Key = _Element>
template <typename TElement>
struct RadixSortIntTraits
{
using Element = _Element;
using Key = _Key;
using Element = TElement;
using Key = Element;
using CountType = uint32_t;
using KeyBits = std::make_unsigned_t<_Key>;
using KeyBits = std::make_unsigned_t<Key>;
static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortSignedTransform<KeyBits>;
using Allocator = RadixSortMallocAllocator;
/// The function to get the key from an array element.
static Key & extractKey(Element & elem)
{
if constexpr (std::is_same_v<Element, Key>)
return elem;
else
return *reinterpret_cast<Key *>(&elem);
}
};
template <typename Int>
struct RadixSortPairIntKeyTraits
{
using Element = std::pair<Int, size_t>;
using Key = Int;
using CountType = uint32_t;
using KeyBits = std::make_unsigned_t<Int>;
static constexpr size_t PART_SIZE_BITS = 8;
using Transform = RadixSortSignedTransform<KeyBits>;
using Allocator = RadixSortMallocAllocator;
/// The function to get the key from an array element.
static Key & extractKey(Element & elem) { return elem.first; }
static Key & extractKey(Element & elem) { return elem; }
};
// Allow std::pair copying
#if defined(__GNUC__) && !defined(__clang__) && (__GNUC__ >= 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
template <typename T>
using RadixSortNumTraits =
std::conditional_t<std::is_integral_v<T>,
std::conditional_t<std::is_unsigned_v<T>,
RadixSortUIntTraits<T>,
RadixSortIntTraits<T>>,
RadixSortFloatTraits<T>>;
template <typename Traits>
struct RadixSort
{
@ -268,8 +201,8 @@ public:
if (!Traits::Transform::transform_is_simple)
Traits::extractKey(arr[i]) = bitsToKey(Traits::Transform::forward(keyToBits(Traits::extractKey(arr[i]))));
for (size_t j = 0; j < NUM_PASSES; ++j)
++histograms[j * HISTOGRAM_SIZE + getPart(j, keyToBits(Traits::extractKey(arr[i])))];
for (size_t pass = 0; pass < NUM_PASSES; ++pass)
++histograms[pass * HISTOGRAM_SIZE + getPart(pass, keyToBits(Traits::extractKey(arr[i])))];
}
{
@ -278,31 +211,31 @@ public:
for (size_t i = 0; i < HISTOGRAM_SIZE; ++i)
{
for (size_t j = 0; j < NUM_PASSES; ++j)
for (size_t pass = 0; pass < NUM_PASSES; ++pass)
{
size_t tmp = histograms[j * HISTOGRAM_SIZE + i] + sums[j];
histograms[j * HISTOGRAM_SIZE + i] = sums[j] - 1;
sums[j] = tmp;
size_t tmp = histograms[pass * HISTOGRAM_SIZE + i] + sums[pass];
histograms[pass * HISTOGRAM_SIZE + i] = sums[pass] - 1;
sums[pass] = tmp;
}
}
}
/// Move the elements in the order starting from the least bit piece, and then do a few passes on the number of pieces.
for (size_t j = 0; j < NUM_PASSES; ++j)
for (size_t pass = 0; pass < NUM_PASSES; ++pass)
{
Element * writer = j % 2 ? arr : swap_buffer;
Element * reader = j % 2 ? swap_buffer : arr;
Element * writer = pass % 2 ? arr : swap_buffer;
Element * reader = pass % 2 ? swap_buffer : arr;
for (size_t i = 0; i < size; ++i)
{
size_t pos = getPart(j, keyToBits(Traits::extractKey(reader[i])));
size_t pos = getPart(pass, keyToBits(Traits::extractKey(reader[i])));
/// Place the element on the next free position.
auto & dest = writer[++histograms[j * HISTOGRAM_SIZE + pos]];
auto & dest = writer[++histograms[pass * HISTOGRAM_SIZE + pos]];
dest = reader[i];
/// On the last pass, we do the reverse transformation.
if (!Traits::Transform::transform_is_simple && j == NUM_PASSES - 1)
if (!Traits::Transform::transform_is_simple && pass == NUM_PASSES - 1)
Traits::extractKey(dest) = bitsToKey(Traits::Transform::backward(keyToBits(Traits::extractKey(reader[i]))));
}
}
@ -314,80 +247,14 @@ public:
allocator.deallocate(swap_buffer, size * sizeof(Element));
}
// For floating point types
// Radix sort sometimes incorrectly handles NaNs
// Will move them to the right place
static void fixNanOrder(Element * arr, size_t size, int nan_direction_hint)
{
if (nan_direction_hint < 0)
{
size_t nans_count = std::count_if(arr, arr + size, [](Element d) {return std::isnan(Traits::extractKey(d));});
std::vector<Element> nans(nans_count);
std::copy(arr + size - nans_count, arr + size, nans.data());
std::copy_backward(arr, arr + size - nans_count, arr + size);
std::copy(nans.data(), nans.data() + nans_count, arr);
}
}
};
#if defined(__GNUC__) && !defined(__clang__) && (__GNUC__ >= 8)
#pragma GCC diagnostic pop
#endif
/// Helper functions for numeric types.
/// Use RadixSort with custom traits for complex types instead.
template <typename T>
std::enable_if_t<std::is_unsigned_v<T> && std::is_integral_v<T>, void>
radixSort(T * arr, size_t size, int /*nan_direction_hint*/=1)
void radixSort(T * arr, size_t size)
{
RadixSort<RadixSortUIntTraits<T>>::execute(arr, size);
}
template <typename T>
std::enable_if_t<std::is_signed_v<T> && std::is_integral_v<T>, void>
radixSort(T * arr, size_t size, int /*nan_direction_hint*/=1)
{
RadixSort<RadixSortIntTraits<T>>::execute(arr, size);
}
template <typename T>
std::enable_if_t<std::is_floating_point_v<T>, void>
radixSort(T * arr, size_t size, int nan_direction_hint=1)
{
RadixSort<RadixSortFloatTraits<T>>::execute(arr, size);
RadixSort<RadixSortFloatTraits<T>>::fixNanOrder(arr, size, nan_direction_hint);
}
template <typename _Element, typename _Key>
std::enable_if_t<std::is_integral_v<_Key>, void>
radixSort(_Element * arr, size_t size)
{
return RadixSort<RadixSortUIntTraits<_Element, _Key>>::execute(arr, size);
}
template <typename _Element, typename _Key>
std::enable_if_t<std::is_floating_point_v<_Key>, void>
radixSort(_Element * arr, size_t size)
{
return RadixSort<RadixSortFloatTraits<_Element, _Key>>::execute(arr, size);
}
template <typename T>
std::enable_if_t<std::is_unsigned_v<T> && !std::is_floating_point_v<T>, void>
radixSort(std::pair<T, size_t> * arr, size_t size, int /*nan_direction_hint*/=1)
{
RadixSort<RadixSortPairUIntKeyTraits<T>>::execute(arr, size);
}
template <typename T>
std::enable_if_t<std::is_signed_v<T> && !std::is_floating_point_v<T>, void>
radixSort(std::pair<T, size_t> * arr, size_t size, int /*nan_direction_hint*/=1)
{
RadixSort<RadixSortPairIntKeyTraits<T>>::execute(arr, size);
}
template <typename T>
std::enable_if_t<std::is_floating_point_v<T>, void>
radixSort(std::pair<T, size_t> * arr, size_t size, int nan_direction_hint=1)
{
RadixSort<RadixSortPairFloatKeyTraits<T>>::execute(arr, size);
RadixSort<RadixSortPairFloatKeyTraits<T>>::fixNanOrder(arr, size, nan_direction_hint);
RadixSort<RadixSortNumTraits<T>>::execute(arr, size);
}

View File

@ -40,11 +40,11 @@ struct RowRefList : RowRef
* references that can be returned by the lookup methods
*/
template <typename _Entry, typename _Key>
template <typename TEntry, typename TKey>
class SortedLookupVector
{
public:
using Base = std::vector<_Entry>;
using Base = std::vector<TEntry>;
// First stage, insertions into the vector
template <typename U, typename ... TAllocatorParams>
@ -55,7 +55,7 @@ public:
}
// Transition into second stage, ensures that the vector is sorted
typename Base::const_iterator upper_bound(const _Entry & k)
typename Base::const_iterator upper_bound(const TEntry & k)
{
sort();
return std::upper_bound(array.cbegin(), array.cend(), k);
@ -70,6 +70,12 @@ private:
Base array;
mutable std::mutex lock;
struct RadixSortTraits : RadixSortNumTraits<TKey>
{
using Element = TEntry;
static TKey & extractKey(Element & elem) { return elem.asof_value; }
};
// Double checked locking with SC atomics works in C++
// https://preshing.com/20130930/double-checked-locking-is-fixed-in-cpp11/
// The first thread that calls one of the lookup methods sorts the data
@ -81,15 +87,15 @@ private:
{
std::lock_guard<std::mutex> l(lock);
if (!sorted.load(std::memory_order_relaxed))
{
/// TODO: It has been tested only for UInt32 yet. It needs to check UInt64, Float32/64.
if constexpr (std::is_same_v<_Key, UInt32>)
{
if (!array.empty())
radixSort<_Entry, _Key>(&array[0], array.size());
}
{
/// TODO: It has been tested only for UInt32 yet. It needs to check UInt64, Float32/64.
if constexpr (std::is_same_v<TKey, UInt32>)
RadixSort<RadixSortTraits>::execute(&array[0], array.size());
else
std::sort(array.begin(), array.end());
}
sorted.store(true, std::memory_order_release);
}