mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 01:25:21 +00:00
Merge pull request #59003 from kitaisreal/revert-57519-revert-53562-group_sorted_array_function
Revert "Revert "Add new aggregation function groupArraySorted()""
This commit is contained in:
commit
7e722b52a9
@ -64,19 +64,14 @@ using ComparatorWrapper = Comparator;
|
||||
|
||||
#include <miniselect/floyd_rivest_select.h>
|
||||
|
||||
template <typename RandomIt>
|
||||
void nth_element(RandomIt first, RandomIt nth, RandomIt last)
|
||||
template <typename RandomIt, typename Compare>
|
||||
void nth_element(RandomIt first, RandomIt nth, RandomIt last, Compare compare)
|
||||
{
|
||||
using value_type = typename std::iterator_traits<RandomIt>::value_type;
|
||||
using comparator = std::less<value_type>;
|
||||
|
||||
comparator compare;
|
||||
ComparatorWrapper<comparator> compare_wrapper = compare;
|
||||
|
||||
#ifndef NDEBUG
|
||||
::shuffle(first, last);
|
||||
#endif
|
||||
|
||||
ComparatorWrapper<Compare> compare_wrapper = compare;
|
||||
::miniselect::floyd_rivest_select(first, nth, last, compare_wrapper);
|
||||
|
||||
#ifndef NDEBUG
|
||||
@ -87,6 +82,15 @@ void nth_element(RandomIt first, RandomIt nth, RandomIt last)
|
||||
#endif
|
||||
}
|
||||
|
||||
/// Convenience overload of nth_element without an explicit comparator:
/// elements are ordered with std::less over the iterator's value type,
/// and the work is delegated to the comparator-taking ::nth_element overload.
template <typename RandomIt>
void nth_element(RandomIt first, RandomIt nth, RandomIt last)
{
    using Less = std::less<typename std::iterator_traits<RandomIt>::value_type>;

    ::nth_element(first, nth, last, Less());
}
|
||||
|
||||
template <typename RandomIt, typename Compare>
|
||||
void partial_sort(RandomIt first, RandomIt middle, RandomIt last, Compare compare)
|
||||
{
|
||||
|
@ -0,0 +1,48 @@
|
||||
---
|
||||
toc_priority: 112
|
||||
---
|
||||
|
||||
# groupArraySorted {#groupArraySorted}
|
||||
|
||||
Returns an array with the first N items in ascending order.
|
||||
|
||||
``` sql
|
||||
groupArraySorted(N)(column)
|
||||
```
|
||||
|
||||
**Arguments**
|
||||
|
||||
- `N` – The number of elements to return.
|
||||
|
||||
The parameter is required and must be a positive integer not greater than 16777215 (0xFFFFFF).
|
||||
|
||||
- `column` – The value (Integer, String, Float and other Generic types).
|
||||
|
||||
**Example**
|
||||
|
||||
Gets the first 10 numbers:
|
||||
|
||||
``` sql
|
||||
SELECT groupArraySorted(10)(number) FROM numbers(100)
|
||||
```
|
||||
|
||||
``` text
|
||||
┌─groupArraySorted(10)(number)─┐
|
||||
│ [0,1,2,3,4,5,6,7,8,9] │
|
||||
└──────────────────────────────┘
|
||||
```
|
||||
|
||||
|
||||
Gets the string representations of all numbers in the column, sorted as strings:
|
||||
|
||||
``` sql
|
||||
SELECT groupArraySorted(5)(str) FROM (SELECT toString(number) as str FROM numbers(5));
|
||||
|
||||
```
|
||||
|
||||
``` text
|
||||
┌─groupArraySorted(5)(str)─────┐
|
||||
│ ['0','1','2','3','4'] │
|
||||
└──────────────────────────────┘
|
||||
```
|
||||
|
@ -54,6 +54,7 @@ ClickHouse-specific aggregate functions:
|
||||
- [groupArrayMovingAvg](/docs/en/sql-reference/aggregate-functions/reference/grouparraymovingavg.md)
|
||||
- [groupArrayMovingSum](/docs/en/sql-reference/aggregate-functions/reference/grouparraymovingsum.md)
|
||||
- [groupArraySample](./grouparraysample.md)
|
||||
- [groupArraySorted](/docs/en/sql-reference/aggregate-functions/reference/grouparraysorted.md)
|
||||
- [groupBitAnd](/docs/en/sql-reference/aggregate-functions/reference/groupbitand.md)
|
||||
- [groupBitOr](/docs/en/sql-reference/aggregate-functions/reference/groupbitor.md)
|
||||
- [groupBitXor](/docs/en/sql-reference/aggregate-functions/reference/groupbitxor.md)
|
||||
|
@ -291,8 +291,17 @@ public:
|
||||
const UInt64 size = value.size();
|
||||
checkArraySize(size, max_elems);
|
||||
writeVarUInt(size, buf);
|
||||
for (const auto & element : value)
|
||||
writeBinaryLittleEndian(element, buf);
|
||||
|
||||
|
||||
if constexpr (std::endian::native == std::endian::little)
|
||||
{
|
||||
buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0]));
|
||||
}
|
||||
else
|
||||
{
|
||||
for (const auto & element : value)
|
||||
writeBinaryLittleEndian(element, buf);
|
||||
}
|
||||
|
||||
if constexpr (Trait::last)
|
||||
writeBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
@ -315,8 +324,16 @@ public:
|
||||
auto & value = this->data(place).value;
|
||||
|
||||
value.resize_exact(size, arena);
|
||||
for (auto & element : value)
|
||||
readBinaryLittleEndian(element, buf);
|
||||
|
||||
if constexpr (std::endian::native == std::endian::little)
|
||||
{
|
||||
buf.readStrict(reinterpret_cast<char *>(value.data()), size * sizeof(value[0]));
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto & element : value)
|
||||
readBinaryLittleEndian(element, buf);
|
||||
}
|
||||
|
||||
if constexpr (Trait::last)
|
||||
readBinaryLittleEndian(this->data(place).total_values, buf);
|
||||
|
414
src/AggregateFunctions/AggregateFunctionGroupArraySorted.cpp
Normal file
414
src/AggregateFunctions/AggregateFunctionGroupArraySorted.cpp
Normal file
@ -0,0 +1,414 @@
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
#include <base/sort.h>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include <Common/RadixSort.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/Operators.h>
|
||||
|
||||
#include <DataTypes/IDataType.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
|
||||
#include <Columns/IColumn.h>
|
||||
#include <Columns/ColumnConst.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
struct Settings;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
enum class GroupArraySortedStrategy
|
||||
{
|
||||
heap,
|
||||
sort
|
||||
};
|
||||
|
||||
constexpr size_t group_array_sorted_sort_strategy_max_elements_threshold = 1000000;
|
||||
|
||||
template <typename T, GroupArraySortedStrategy strategy>
|
||||
struct GroupArraySortedData
|
||||
{
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
|
||||
using Array = PODArray<T, 32, Allocator>;
|
||||
|
||||
static constexpr size_t partial_sort_max_elements_factor = 2;
|
||||
|
||||
static constexpr bool is_value_generic_field = std::is_same_v<T, Field>;
|
||||
|
||||
Array values;
|
||||
|
||||
static bool compare(const T & lhs, const T & rhs)
|
||||
{
|
||||
if constexpr (is_value_generic_field)
|
||||
{
|
||||
return lhs < rhs;
|
||||
}
|
||||
else
|
||||
{
|
||||
return CompareHelper<T>::less(lhs, rhs, -1);
|
||||
}
|
||||
}
|
||||
|
||||
struct Comparator
|
||||
{
|
||||
bool operator()(const T & lhs, const T & rhs)
|
||||
{
|
||||
return compare(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
ALWAYS_INLINE void heapReplaceTop()
|
||||
{
|
||||
size_t size = values.size();
|
||||
if (size < 2)
|
||||
return;
|
||||
|
||||
size_t child_index = 1;
|
||||
|
||||
if (values.size() > 2 && compare(values[1], values[2]))
|
||||
++child_index;
|
||||
|
||||
/// Check if we are in order
|
||||
if (compare(values[child_index], values[0]))
|
||||
return;
|
||||
|
||||
size_t current_index = 0;
|
||||
auto current = values[current_index];
|
||||
|
||||
do
|
||||
{
|
||||
/// We are not in heap-order, swap the parent with it's largest child.
|
||||
values[current_index] = values[child_index];
|
||||
current_index = child_index;
|
||||
|
||||
// Recompute the child based off of the updated parent
|
||||
child_index = 2 * child_index + 1;
|
||||
|
||||
if (child_index >= size)
|
||||
break;
|
||||
|
||||
if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1]))
|
||||
{
|
||||
/// Right child exists and is greater than left child.
|
||||
++child_index;
|
||||
}
|
||||
|
||||
/// Check if we are in order.
|
||||
} while (!compare(values[child_index], current));
|
||||
|
||||
values[current_index] = current;
|
||||
}
|
||||
|
||||
ALWAYS_INLINE void sortAndLimit(size_t max_elements, Arena * arena)
|
||||
{
|
||||
if constexpr (is_value_generic_field)
|
||||
{
|
||||
::sort(values.begin(), values.end(), Comparator());
|
||||
}
|
||||
else
|
||||
{
|
||||
bool try_sort = trySort(values.begin(), values.end(), Comparator());
|
||||
if (!try_sort)
|
||||
RadixSort<RadixSortNumTraits<T>>::executeLSD(values.data(), values.size());
|
||||
}
|
||||
|
||||
if (values.size() > max_elements)
|
||||
values.resize(max_elements, arena);
|
||||
}
|
||||
|
||||
ALWAYS_INLINE void partialSortAndLimitIfNeeded(size_t max_elements, Arena * arena)
|
||||
{
|
||||
if (values.size() < max_elements * partial_sort_max_elements_factor)
|
||||
return;
|
||||
|
||||
::nth_element(values.begin(), values.begin() + max_elements, values.end(), Comparator());
|
||||
values.resize(max_elements, arena);
|
||||
}
|
||||
|
||||
ALWAYS_INLINE void addElement(T && element, size_t max_elements, Arena * arena)
|
||||
{
|
||||
if constexpr (strategy == GroupArraySortedStrategy::heap)
|
||||
{
|
||||
if (values.size() >= max_elements)
|
||||
{
|
||||
/// Element is greater or equal than current max element, it cannot be in k min elements
|
||||
if (!compare(element, values[0]))
|
||||
return;
|
||||
|
||||
values[0] = std::move(element);
|
||||
heapReplaceTop();
|
||||
return;
|
||||
}
|
||||
|
||||
values.push_back(std::move(element), arena);
|
||||
std::push_heap(values.begin(), values.end(), Comparator());
|
||||
}
|
||||
else
|
||||
{
|
||||
values.push_back(std::move(element), arena);
|
||||
partialSortAndLimitIfNeeded(max_elements, arena);
|
||||
}
|
||||
}
|
||||
|
||||
ALWAYS_INLINE void insertResultInto(IColumn & to, size_t max_elements, Arena * arena)
|
||||
{
|
||||
auto & result_array = assert_cast<ColumnArray &>(to);
|
||||
auto & result_array_offsets = result_array.getOffsets();
|
||||
|
||||
sortAndLimit(max_elements, arena);
|
||||
|
||||
result_array_offsets.push_back(result_array_offsets.back() + values.size());
|
||||
|
||||
if (values.empty())
|
||||
return;
|
||||
|
||||
if constexpr (is_value_generic_field)
|
||||
{
|
||||
auto & result_array_data = result_array.getData();
|
||||
for (auto & value : values)
|
||||
result_array_data.insert(value);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto & result_array_data = assert_cast<ColumnVector<T> &>(result_array.getData()).getData();
|
||||
|
||||
size_t result_array_data_insert_begin = result_array_data.size();
|
||||
result_array_data.resize(result_array_data_insert_begin + values.size());
|
||||
|
||||
for (size_t i = 0; i < values.size(); ++i)
|
||||
result_array_data[result_array_data_insert_begin + i] = values[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using GroupArraySortedDataHeap = GroupArraySortedData<T, GroupArraySortedStrategy::heap>;
|
||||
|
||||
template <typename T>
|
||||
using GroupArraySortedDataSort = GroupArraySortedData<T, GroupArraySortedStrategy::sort>;
|
||||
|
||||
constexpr UInt64 aggregate_function_group_array_sorted_max_element_size = 0xFFFFFF;
|
||||
|
||||
template <typename Data, typename T>
|
||||
class GroupArraySorted final
|
||||
: public IAggregateFunctionDataHelper<Data, GroupArraySorted<Data, T>>
|
||||
{
|
||||
public:
|
||||
explicit GroupArraySorted(
|
||||
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elements_)
|
||||
: IAggregateFunctionDataHelper<Data, GroupArraySorted<Data, T>>(
|
||||
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
|
||||
, max_elements(max_elements_)
|
||||
, serialization(data_type_->getDefaultSerialization())
|
||||
{
|
||||
if (max_elements > aggregate_function_group_array_sorted_max_element_size)
|
||||
throw Exception(ErrorCodes::BAD_ARGUMENTS,
|
||||
"Too large limit parameter for groupArraySorted aggregate function, it should not exceed {}",
|
||||
aggregate_function_group_array_sorted_max_element_size);
|
||||
}
|
||||
|
||||
String getName() const override { return "groupArraySorted"; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
if constexpr (std::is_same_v<T, Field>)
|
||||
{
|
||||
auto row_value = (*columns[0])[row_num];
|
||||
this->data(place).addElement(std::move(row_value), max_elements, arena);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto row_value = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
this->data(place).addElement(std::move(row_value), max_elements, arena);
|
||||
}
|
||||
}
|
||||
|
||||
/// Folds every element held by the `rhs` state into the state at `place`.
/// Elements go through addElement, so the `max_elements` limit is applied to the merged state.
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
    auto & destination = this->data(place);
    const auto & source_values = this->data(rhs).values;

    for (const auto & source_element : source_values)
    {
        /// addElement consumes an rvalue, so give it a private copy and leave the source state untouched.
        auto element_copy = source_element;
        destination.addElement(std::move(element_copy), max_elements, arena);
    }
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
{
|
||||
auto & values = this->data(place).values;
|
||||
size_t size = values.size();
|
||||
writeVarUInt(size, buf);
|
||||
|
||||
if constexpr (std::is_same_v<T, Field>)
|
||||
{
|
||||
for (const Field & element : values)
|
||||
{
|
||||
if (element.isNull())
|
||||
{
|
||||
writeBinary(false, buf);
|
||||
}
|
||||
else
|
||||
{
|
||||
writeBinary(true, buf);
|
||||
serialization->serializeBinary(element, buf, {});
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr (std::endian::native == std::endian::little)
|
||||
{
|
||||
buf.write(reinterpret_cast<const char *>(values.data()), size * sizeof(values[0]));
|
||||
}
|
||||
else
|
||||
{
|
||||
for (const auto & element : values)
|
||||
writeBinaryLittleEndian(element, buf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Restores the aggregation state previously written by serialize().
///
/// Layout read from `buf`: a VarUInt element count, then the elements.
///  - Generic (Field) values: each element is preceded by a one-byte flag. serialize() writes the flag
///    as `true` when a serialized value follows and as `false` when the element is Null.
///  - Numeric values: raw little-endian data, read in one block on little-endian hosts.
///
/// Throws TOO_LARGE_ARRAY_SIZE when the stored count exceeds `max_elements`.
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
    size_t size = 0;
    readVarUInt(size, buf);

    if (unlikely(size > max_elements))
        throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size, it should not exceed {}", max_elements);

    auto & values = this->data(place).values;
    values.resize_exact(size, arena);

    if constexpr (std::is_same_v<T, Field>)
    {
        for (Field & element : values)
        {
            /// NOTE(review): resize_exact on a PODArray presumably leaves the new slots unconstructed.
            /// Construct a Null Field in place so that the assignment performed by deserializeBinary
            /// (and the "no value" case below) operates on a valid object.
            new (&element) Field;

            /// The flag means "a value follows", mirroring serialize(); it must not be read as "is null".
            UInt8 has_value = 0;
            readBinary(has_value, buf);
            if (has_value)
                serialization->deserializeBinary(element, buf, {});
        }
    }
    else
    {
        if constexpr (std::endian::native == std::endian::little)
        {
            buf.readStrict(reinterpret_cast<char *>(values.data()), size * sizeof(values[0]));
        }
        else
        {
            for (auto & element : values)
                readBinaryLittleEndian(element, buf);
        }
    }
}
|
||||
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
|
||||
{
|
||||
this->data(place).insertResultInto(to, max_elements, arena);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override { return true; }
|
||||
|
||||
private:
|
||||
UInt64 max_elements;
|
||||
SerializationPtr serialization;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using GroupArraySortedHeap = GroupArraySorted<GroupArraySortedDataHeap<T>, T>;
|
||||
|
||||
template <typename T>
|
||||
using GroupArraySortedSort = GroupArraySorted<GroupArraySortedDataSort<T>, T>;
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
|
||||
AggregateFunctionPtr createWithNumericOrTimeType(const IDataType & argument_type, TArgs && ... args)
|
||||
{
|
||||
WhichDataType which(argument_type);
|
||||
|
||||
if (which.idx == TypeIndex::Date) return std::make_shared<AggregateFunctionTemplate<UInt16>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::DateTime) return std::make_shared<AggregateFunctionTemplate<UInt32>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::IPv4) return std::make_shared<AggregateFunctionTemplate<IPv4>>(std::forward<TArgs>(args)...);
|
||||
|
||||
return AggregateFunctionPtr(createWithNumericType<AggregateFunctionTemplate, TArgs...>(argument_type, std::forward<TArgs>(args)...));
|
||||
}
|
||||
|
||||
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
|
||||
inline AggregateFunctionPtr createAggregateFunctionGroupArraySortedImpl(const DataTypePtr & argument_type, const Array & parameters, TArgs ... args)
|
||||
{
|
||||
if (auto res = createWithNumericOrTimeType<AggregateFunctionTemplate>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))
|
||||
return AggregateFunctionPtr(res);
|
||||
|
||||
return std::make_shared<AggregateFunctionTemplate<Field>>(argument_type, parameters, std::forward<TArgs>(args)...);
|
||||
}
|
||||
|
||||
/// Factory callback for groupArraySorted(N)(column).
///
/// Requires exactly one argument column and exactly one parameter N, which must be a strictly
/// positive Int64/UInt64 literal. Limits above group_array_sorted_sort_strategy_max_elements_threshold
/// use the sort-based state, smaller ones use the heap-based state.
///
/// Throws BAD_ARGUMENTS when the parameter is missing, has a non-integer type or is not positive,
/// and NUMBER_OF_ARGUMENTS_DOESNT_MATCH when more than one parameter is given.
AggregateFunctionPtr createAggregateFunctionGroupArray(
    const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
    assertUnary(name, argument_types);

    if (parameters.empty())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should have limit argument", name);

    if (parameters.size() != 1)
        throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
            "Function {} does not support this number of arguments", name);

    const auto type = parameters[0].getType();
    if (type != Field::Types::Int64 && type != Field::Types::UInt64)
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be positive number", name);

    /// Zero must be rejected whatever the signedness of the literal: with a zero limit the heap
    /// strategy would compare against values[0] of an empty array.
    if ((type == Field::Types::Int64 && parameters[0].get<Int64>() <= 0) ||
        (type == Field::Types::UInt64 && parameters[0].get<UInt64>() == 0))
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be positive number", name);

    const UInt64 max_elems = parameters[0].get<UInt64>();

    if (max_elems > group_array_sorted_sort_strategy_max_elements_threshold)
        return createAggregateFunctionGroupArraySortedImpl<GroupArraySortedSort>(argument_types[0], parameters, max_elems);

    return createAggregateFunctionGroupArraySortedImpl<GroupArraySortedHeap>(argument_types[0], parameters, max_elems);
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionGroupArraySorted(AggregateFunctionFactory & factory)
|
||||
{
|
||||
AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = false };
|
||||
|
||||
factory.registerFunction("groupArraySorted", { createAggregateFunctionGroupArray, properties });
|
||||
}
|
||||
|
||||
}
|
@ -1,2 +1,5 @@
|
||||
clickhouse_add_executable (quantile-t-digest quantile-t-digest.cpp)
|
||||
target_link_libraries (quantile-t-digest PRIVATE dbms clickhouse_aggregate_functions)
|
||||
|
||||
clickhouse_add_executable (group_array_sorted group_array_sorted.cpp)
|
||||
target_link_libraries (group_array_sorted PRIVATE dbms clickhouse_aggregate_functions)
|
||||
|
205
src/AggregateFunctions/examples/group_array_sorted.cpp
Normal file
205
src/AggregateFunctions/examples/group_array_sorted.cpp
Normal file
@ -0,0 +1,205 @@
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <iostream>
|
||||
|
||||
#include "pcg_random.hpp"
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/RadixSort.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
|
||||
|
||||
using namespace DB;
|
||||
|
||||
template <typename T>
|
||||
struct GroupArraySortedDataHeap
|
||||
{
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
|
||||
using Array = PODArray<T, 32, Allocator>;
|
||||
|
||||
Array values;
|
||||
|
||||
static bool compare(const T & lhs, const T & rhs)
|
||||
{
|
||||
return lhs < rhs;
|
||||
}
|
||||
|
||||
struct Comparator
|
||||
{
|
||||
bool operator()(const T & lhs, const T & rhs)
|
||||
{
|
||||
return compare(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
ALWAYS_INLINE void replaceTop()
|
||||
{
|
||||
size_t size = values.size();
|
||||
if (size < 2)
|
||||
return;
|
||||
|
||||
size_t child_index = 1;
|
||||
|
||||
if (values.size() > 2 && compare(values[1], values[2]))
|
||||
++child_index;
|
||||
|
||||
/// Check if we are in order
|
||||
if (compare(values[child_index], values[0]))
|
||||
return;
|
||||
|
||||
size_t current_index = 0;
|
||||
auto current = values[current_index];
|
||||
|
||||
do
|
||||
{
|
||||
/// We are not in heap-order, swap the parent with it's largest child.
|
||||
values[current_index] = values[child_index];
|
||||
current_index = child_index;
|
||||
|
||||
// Recompute the child based off of the updated parent
|
||||
child_index = 2 * child_index + 1;
|
||||
|
||||
if (child_index >= size)
|
||||
break;
|
||||
|
||||
if ((child_index + 1) < size && compare(values[child_index], values[child_index + 1]))
|
||||
{
|
||||
/// Right child exists and is greater than left child.
|
||||
++child_index;
|
||||
}
|
||||
|
||||
/// Check if we are in order.
|
||||
} while (!compare(values[child_index], current));
|
||||
|
||||
values[current_index] = current;
|
||||
}
|
||||
|
||||
ALWAYS_INLINE void addElement(const T & element, size_t max_elements, Arena * arena)
|
||||
{
|
||||
if (values.size() >= max_elements)
|
||||
{
|
||||
/// Element is greater or equal than current max element, it cannot be in k min elements
|
||||
if (!compare(element, values[0]))
|
||||
return;
|
||||
|
||||
values[0] = element;
|
||||
replaceTop();
|
||||
return;
|
||||
}
|
||||
|
||||
values.push_back(element, arena);
|
||||
std::push_heap(values.begin(), values.end(), Comparator());
|
||||
}
|
||||
|
||||
ALWAYS_INLINE void dump()
|
||||
{
|
||||
while (!values.empty())
|
||||
{
|
||||
std::pop_heap(values.begin(), values.end(), Comparator());
|
||||
std::cerr << values.back() << ' ';
|
||||
values.pop_back();
|
||||
}
|
||||
|
||||
std::cerr << '\n';
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GroupArraySortedDataSort
|
||||
{
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
|
||||
using Array = PODArray<T, 32, Allocator>;
|
||||
|
||||
Array values;
|
||||
|
||||
static bool compare(const T & lhs, const T & rhs)
|
||||
{
|
||||
return lhs < rhs;
|
||||
}
|
||||
|
||||
struct Comparator
|
||||
{
|
||||
bool operator()(const T & lhs, const T & rhs)
|
||||
{
|
||||
return compare(lhs, rhs);
|
||||
}
|
||||
};
|
||||
|
||||
ALWAYS_INLINE void sortAndLimit(size_t max_elements, Arena * arena)
|
||||
{
|
||||
RadixSort<RadixSortNumTraits<T>>::executeLSD(values.data(), values.size());
|
||||
values.resize(max_elements, arena);
|
||||
}
|
||||
|
||||
ALWAYS_INLINE void partialSortAndLimitIfNeeded(size_t max_elements, Arena * arena)
|
||||
{
|
||||
if (values.size() < max_elements * 4)
|
||||
return;
|
||||
|
||||
std::nth_element(values.begin(), values.begin() + max_elements, values.end(), Comparator());
|
||||
values.resize(max_elements, arena);
|
||||
}
|
||||
|
||||
ALWAYS_INLINE void addElement(const T & element, size_t max_elements, Arena * arena)
|
||||
{
|
||||
values.push_back(element, arena);
|
||||
partialSortAndLimitIfNeeded(max_elements, arena);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SortedData>
|
||||
NO_INLINE void benchmark(size_t elements, size_t max_elements)
|
||||
{
|
||||
Stopwatch watch;
|
||||
watch.start();
|
||||
|
||||
SortedData data;
|
||||
pcg64_fast rng;
|
||||
|
||||
Arena arena;
|
||||
|
||||
for (size_t i = 0; i < elements; ++i)
|
||||
{
|
||||
uint64_t value = rng();
|
||||
data.addElement(value, max_elements, &arena);
|
||||
}
|
||||
|
||||
watch.stop();
|
||||
std::cerr << "Elapsed " << watch.elapsedMilliseconds() << " milliseconds" << '\n';
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv)
|
||||
{
|
||||
(void)(argc);
|
||||
(void)(argv);
|
||||
|
||||
if (argc != 4)
|
||||
{
|
||||
std::cerr << "./group_array_sorted method elements max_elements" << '\n';
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string method = std::string(argv[1]);
|
||||
uint64_t elements = std::atol(argv[2]); /// NOLINT
|
||||
uint64_t max_elements = std::atol(argv[3]); /// NOLINT
|
||||
|
||||
std::cerr << "Method " << method << " elements " << elements << " max elements " << max_elements << '\n';
|
||||
|
||||
if (method == "heap")
|
||||
{
|
||||
benchmark<GroupArraySortedDataHeap<UInt64>>(elements, max_elements);
|
||||
}
|
||||
else if (method == "sort")
|
||||
{
|
||||
benchmark<GroupArraySortedDataSort<UInt64>>(elements, max_elements);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Invalid method " << method << '\n';
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
@ -15,6 +15,7 @@ void registerAggregateFunctionCount(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionDeltaSum(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionDeltaSumTimestamp(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionGroupArray(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionGroupArraySorted(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionGroupUniqArray(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionGroupArrayInsertAt(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsQuantile(AggregateFunctionFactory &);
|
||||
@ -112,6 +113,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionDeltaSum(factory);
|
||||
registerAggregateFunctionDeltaSumTimestamp(factory);
|
||||
registerAggregateFunctionGroupArray(factory);
|
||||
registerAggregateFunctionGroupArraySorted(factory);
|
||||
registerAggregateFunctionGroupUniqArray(factory);
|
||||
registerAggregateFunctionGroupArrayInsertAt(factory);
|
||||
registerAggregateFunctionsQuantile(factory);
|
||||
|
31
tests/performance/group_array_sorted.xml
Normal file
31
tests/performance/group_array_sorted.xml
Normal file
@ -0,0 +1,31 @@
|
||||
<test>
|
||||
<settings>
|
||||
<max_memory_usage>30000000000</max_memory_usage>
|
||||
</settings>
|
||||
|
||||
<substitutions>
|
||||
<substitution>
|
||||
<name>millions</name>
|
||||
<values>
|
||||
<value>50</value>
|
||||
<value>100</value>
|
||||
</values>
|
||||
</substitution>
|
||||
<substitution>
|
||||
<name>window</name>
|
||||
<values>
|
||||
<value>10</value>
|
||||
<value>1000</value>
|
||||
<value>10000</value>
|
||||
</values>
|
||||
</substitution>
|
||||
</substitutions>
|
||||
|
||||
<create_query>create table sorted_{millions}m engine MergeTree order by k as select number % 100 k, rand() v from numbers_mt(1000000 * {millions})</create_query>
|
||||
<create_query>optimize table sorted_{millions}m final</create_query>
|
||||
|
||||
<query>select k, groupArraySorted({window})(v) from sorted_{millions}m group by k format Null</query>
|
||||
<query>select k % 10 kk, groupArraySorted({window})(v) from sorted_{millions}m group by kk format Null</query>
|
||||
|
||||
<drop_query>drop table if exists sorted_{millions}m</drop_query>
|
||||
</test>
|
12
tests/queries/0_stateless/02841_group_array_sorted.reference
Normal file
12
tests/queries/0_stateless/02841_group_array_sorted.reference
Normal file
@ -0,0 +1,12 @@
|
||||
[0,1,2,3,4]
|
||||
[0,1,2,3,4]
|
||||
[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99]
|
||||
['0','1','10','11','12','13','14','15','16','17','18','19','2','20','21','22','23','24','25','26','27','28','29','3','4','5','6','7','8','9']
|
||||
[0,0,1,1,2,2,3,3,4,4]
|
||||
[[1,2,3,4],[2,3,4,5],[3,4,5,6]]
|
||||
[(2,1),(15,25),(30,60),(100,200)]
|
||||
[0.2,2.2,6.6,12.5]
|
||||
['AAA','Aaa','aaa','abc','bbc']
|
||||
1000000
|
||||
1000000
|
||||
[0,1]
|
41
tests/queries/0_stateless/02841_group_array_sorted.sql
Normal file
41
tests/queries/0_stateless/02841_group_array_sorted.sql
Normal file
@ -0,0 +1,41 @@
|
||||
SELECT groupArraySorted(5)(number) FROM numbers(100);
|
||||
|
||||
SELECT groupArraySorted(10)(number) FROM numbers(5);
|
||||
|
||||
SELECT groupArraySorted(100)(number) FROM numbers(1000);
|
||||
|
||||
SELECT groupArraySorted(30)(str) FROM (SELECT toString(number) as str FROM numbers(30));
|
||||
|
||||
SELECT groupArraySorted(10)(toInt64(number/2)) FROM numbers(100);
|
||||
|
||||
DROP TABLE IF EXISTS test;
|
||||
CREATE TABLE test (a Array(UInt64)) engine=MergeTree ORDER BY a;
|
||||
INSERT INTO test VALUES ([3,4,5,6]), ([1,2,3,4]), ([2,3,4,5]);
|
||||
SELECT groupArraySorted(3)(a) FROM test;
|
||||
DROP TABLE test;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS test (id Int32, data Tuple(Int32, Int32)) ENGINE = MergeTree() ORDER BY id;
|
||||
INSERT INTO test (id, data) VALUES (1, (100, 200)), (2, (15, 25)), (3, (2, 1)), (4, (30, 60));
|
||||
SELECT groupArraySorted(4)(data) FROM test;
|
||||
DROP TABLE test;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS test (id Int32, data Decimal32(2)) ENGINE = MergeTree() ORDER BY id;
|
||||
INSERT INTO test (id, data) VALUES (1, 12.5), (2, 0.2), (3, 6.6), (4, 2.2);
|
||||
SELECT groupArraySorted(4)(data) FROM test;
|
||||
DROP TABLE test;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS test (id Int32, data FixedString(3)) ENGINE = MergeTree() ORDER BY id;
|
||||
INSERT INTO test (id, data) VALUES (1, 'AAA'), (2, 'bbc'), (3, 'abc'), (4, 'aaa'), (5, 'Aaa');
|
||||
SELECT groupArraySorted(5)(data) FROM test;
|
||||
DROP TABLE test;
|
||||
|
||||
CREATE TABLE test (id Decimal(76, 53), str String) ENGINE = MergeTree ORDER BY id;
|
||||
INSERT INTO test SELECT number, 'test' FROM numbers(1000000);
|
||||
SELECT count(id) FROM test;
|
||||
SELECT count(concat(toString(id), 'a')) FROM test;
|
||||
DROP TABLE test;
|
||||
|
||||
CREATE TABLE test (id UInt64, agg AggregateFunction(groupArraySorted(2), UInt64)) engine=MergeTree ORDER BY id;
|
||||
INSERT INTO test SELECT 1, groupArraySortedState(2)(number) FROM numbers(10);
|
||||
SELECT groupArraySortedMerge(2)(agg) FROM test;
|
||||
DROP TABLE test;
|
@ -1593,6 +1593,7 @@ groupArrayLast
|
||||
groupArrayMovingAvg
|
||||
groupArrayMovingSum
|
||||
groupArraySample
|
||||
groupArraySorted
|
||||
groupBitAnd
|
||||
groupBitOr
|
||||
groupBitXor
|
||||
@ -1607,6 +1608,7 @@ grouparraylast
|
||||
grouparraymovingavg
|
||||
grouparraymovingsum
|
||||
grouparraysample
|
||||
grouparraysorted
|
||||
groupbitand
|
||||
groupbitmap
|
||||
groupbitmapand
|
||||
|
Loading…
Reference in New Issue
Block a user