Clarifications #10981

This commit is contained in:
Alexey Milovidov 2020-05-23 17:28:05 +03:00
parent ac6deaf8ec
commit 73a3394b3d
3 changed files with 39 additions and 22 deletions

View File

@ -103,7 +103,7 @@ class QuantileTDigest
struct RadixSortTraits
{
using Element = Centroid;
using Index = Element;
using Result = Element;
using Key = Value;
using CountType = UInt32;
using KeyBits = UInt32;
@ -115,7 +115,7 @@ class QuantileTDigest
/// The function to get the key from an array element.
static Key & extractKey(Element & elem) { return elem.mean; }
static Index & extractIndex(Element & elem) { return elem; }
static Result & extractResult(Element & elem) { return elem; }
};
/** Adds a centroid `c` to the digest

View File

@ -118,9 +118,9 @@ namespace
struct RadixSortTraits : RadixSortNumTraits<T>
{
using Element = ValueWithIndex<T>;
using Index = size_t;
using Result = size_t;
static T & extractKey(Element & elem) { return elem.value; }
static size_t extractIndex(Element & elem) { return elem.index; }
static size_t extractResult(Element & elem) { return elem.index; }
};
}

View File

@ -70,10 +70,18 @@ struct RadixSortFloatTransform
template <typename TElement>
struct RadixSortFloatTraits
{
using Element = TElement; /// The type of the element. It can be a structure with a key and some other payload. Or just a key.
using Index = Element; /// The index type to store permutation if needed
using Key = Element; /// The key to sort by.
using CountType = uint32_t; /// Type for calculating histograms. In the case of a known small number of elements, it can be less than size_t.
/// The type of the element. It can be a structure with a key and some other payload. Or just a key.
using Element = TElement;
/// The key to sort by.
using Key = Element;
/// Part of the element that you need in the result array.
/// There are cases when elements are sorted by one part but you need other parts in the array of results.
using Result = Element;
/// Type for calculating histograms. In the case of a known small number of elements, it can be less than size_t.
using CountType = uint32_t;
/// The type to which the key is transformed to do bit operations. This UInt is the same size as the key.
using KeyBits = std::conditional_t<sizeof(Key) == 8, uint64_t, uint32_t>;
@ -91,8 +99,8 @@ struct RadixSortFloatTraits
/// The function to get the key from an array element.
static Key & extractKey(Element & elem) { return elem; }
/// The function to get the index from an array.
static Index & extractIndex(Element & elem) { return elem; }
/// The function to get the result part from an array element.
static Result & extractResult(Element & elem) { return elem; }
/// Used when fallback to comparison based sorting is needed.
/// TODO: Correct handling of NaNs, NULLs, etc
@ -117,7 +125,7 @@ template <typename TElement>
struct RadixSortUIntTraits
{
using Element = TElement;
using Index = Element;
using Result = Element;
using Key = Element;
using CountType = uint32_t;
using KeyBits = Key;
@ -128,7 +136,7 @@ struct RadixSortUIntTraits
using Allocator = RadixSortMallocAllocator;
static Key & extractKey(Element & elem) { return elem; }
static Index & extractIndex(Element & elem) { return elem; }
static Result & extractResult(Element & elem) { return elem; }
static bool less(Key x, Key y)
{
@ -151,7 +159,7 @@ template <typename TElement>
struct RadixSortIntTraits
{
using Element = TElement;
using Index = Element;
using Result = Element;
using Key = Element;
using CountType = uint32_t;
using KeyBits = std::make_unsigned_t<Key>;
@ -162,7 +170,7 @@ struct RadixSortIntTraits
using Allocator = RadixSortMallocAllocator;
static Key & extractKey(Element & elem) { return elem; }
static Index & extractIndex(Element & elem) { return elem; }
static Result & extractResult(Element & elem) { return elem; }
static bool less(Key x, Key y)
{
@ -183,7 +191,7 @@ struct RadixSort
{
private:
using Element = typename Traits::Element;
using Index = typename Traits::Index;
using Result = typename Traits::Result;
using Key = typename Traits::Key;
using CountType = typename Traits::CountType;
using KeyBits = typename Traits::KeyBits;
@ -312,8 +320,15 @@ private:
}
public:
/// Least significant digit radix sort (stable)
static void executeLSD(Element * arr, size_t size, bool reverse = false, Index * destination = nullptr)
/** Least significant digit radix sort (stable).
*
* This function sorts in place (it modifies 'arr'),
* but if 'destination' is provided, it writes the result directly to 'destination'
* instead of finishing the sort of 'arr' at the last step.
* In this case only the Result part of each Element is written to 'destination'.
* This is handy to avoid unnecessary data movement.
*/
static void executeLSD(Element * arr, size_t size, bool reverse = false, Result * destination = nullptr)
{
/// If the array is smaller than 256, then it is better to use another algorithm.
@ -378,7 +393,7 @@ public:
if (direct_copy_to_destination)
{
size_t pass = NUM_PASSES - 1;
Index * writer = destination;
Result * writer = destination;
Element * reader = pass % 2 ? swap_buffer : arr;
if (reverse)
@ -388,7 +403,7 @@ public:
size_t pos = getPart(pass, keyToBits(Traits::extractKey(reader[i])));
/// Place the element on the next free position.
writer[size - 1 - (++histograms[pass * HISTOGRAM_SIZE + pos])] = Traits::extractIndex(reader[i]);
writer[size - 1 - (++histograms[pass * HISTOGRAM_SIZE + pos])] = Traits::extractResult(reader[i]);
}
}
else
@ -398,15 +413,17 @@ public:
size_t pos = getPart(pass, keyToBits(Traits::extractKey(reader[i])));
/// Place the element on the next free position.
writer[++histograms[pass * HISTOGRAM_SIZE + pos]] = Traits::extractIndex(reader[i]);
writer[++histograms[pass * HISTOGRAM_SIZE + pos]] = Traits::extractResult(reader[i]);
}
}
} else if (NUM_PASSES % 2)
}
else if (NUM_PASSES % 2)
{
/// If the number of passes is odd, the result array is in a temporary buffer. Copy it to the place of the original array.
/// NOTE Sometimes it would be more optimal to provide a non-destructive interface that does not modify the original array.
memcpy(arr, swap_buffer, size * sizeof(Element));
} else if (reverse)
}
else if (reverse)
{
std::reverse(arr, arr + size);
}