Revert "Revert "Add new aggregation function groupArraySorted()""

This commit is contained in:
Maksim Kita 2024-01-19 18:37:02 +03:00
parent c339a74ac3
commit 20c1f0c18f
9 changed files with 574 additions and 0 deletions

View File

@ -0,0 +1,48 @@
---
toc_priority: 112
---
# groupArraySorted {#groupArraySorted}
Returns an array with the first N items in ascending order.
``` sql
groupArraySorted(N)(column)
```
**Arguments**
- `N` The number of elements to return.
If the parameter is omitted, default value is the size of input.
- `column` The value (Integer, String, Float and other Generic types).
**Example**
Gets the first 10 numbers:
``` sql
SELECT groupArraySorted(10)(number) FROM numbers(100)
```
``` text
┌─groupArraySorted(10)(number)─┐
│ [0,1,2,3,4,5,6,7,8,9] │
└──────────────────────────────┘
```
Gets all the String implementations of all numbers in column:
``` sql
SELECT groupArraySorted(str) FROM (SELECT toString(number) as str FROM numbers(5));
```
``` text
┌─groupArraySorted(str)────────┐
│ ['0','1','2','3','4'] │
└──────────────────────────────┘
```

View File

@ -54,6 +54,7 @@ ClickHouse-specific aggregate functions:
- [groupArrayMovingAvg](/docs/en/sql-reference/aggregate-functions/reference/grouparraymovingavg.md)
- [groupArrayMovingSum](/docs/en/sql-reference/aggregate-functions/reference/grouparraymovingsum.md)
- [groupArraySample](./grouparraysample.md)
- [groupArraySorted](/docs/en/sql-reference/aggregate-functions/reference/grouparraysorted.md)
- [groupBitAnd](/docs/en/sql-reference/aggregate-functions/reference/groupbitand.md)
- [groupBitOr](/docs/en/sql-reference/aggregate-functions/reference/groupbitor.md)
- [groupBitXor](/docs/en/sql-reference/aggregate-functions/reference/groupbitxor.md)

View File

@ -0,0 +1,82 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionGroupArraySorted.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <Common/Exception.h>
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
}
namespace
{
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
AggregateFunctionPtr createWithNumericOrTimeType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date) return std::make_shared<AggregateFunctionTemplate<UInt16>>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::DateTime) return std::make_shared<AggregateFunctionTemplate<UInt32>>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::IPv4) return std::make_shared<AggregateFunctionTemplate<IPv4>>(std::forward<TArgs>(args)...);
return AggregateFunctionPtr(createWithNumericType<AggregateFunctionTemplate, TArgs...>(argument_type, std::forward<TArgs>(args)...));
}
template <typename ... TArgs>
inline AggregateFunctionPtr createAggregateFunctionGroupArraySortedImpl(const DataTypePtr & argument_type, const Array & parameters, TArgs ... args)
{
if (auto res = createWithNumericOrTimeType<GroupArraySortedNumericImpl>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))
return AggregateFunctionPtr(res);
WhichDataType which(argument_type);
return std::make_shared<GroupArraySortedGeneralImpl<GroupArraySortedNodeGeneral>>(argument_type, parameters, std::forward<TArgs>(args)...);
}
AggregateFunctionPtr createAggregateFunctionGroupArraySorted(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
{
assertUnary(name, argument_types);
UInt64 max_elems = std::numeric_limits<UInt64>::max();
if (parameters.empty())
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should have limit argument", name);
}
else if (parameters.size() == 1)
{
auto type = parameters[0].getType();
if (type != Field::Types::Int64 && type != Field::Types::UInt64)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be positive number", name);
if ((type == Field::Types::Int64 && parameters[0].get<Int64>() < 0) ||
(type == Field::Types::UInt64 && parameters[0].get<UInt64>() == 0))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be positive number", name);
max_elems = parameters[0].get<UInt64>();
}
else
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} does not support this number of arguments", name);
return createAggregateFunctionGroupArraySortedImpl(argument_types[0], parameters, max_elems);
}
}
void registerAggregateFunctionGroupArraySorted(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = false };
factory.registerFunction("groupArraySorted", { createAggregateFunctionGroupArraySorted, properties });
}
}

View File

@ -0,0 +1,355 @@
#pragma once
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Functions/array/arraySort.h>
#include <Common/Exception.h>
#include <Common/ArenaAllocator.h>
#include <Common/assert_cast.h>
#include <Columns/ColumnConst.h>
#include <DataTypes/IDataType.h>
#include <base/sort.h>
#include <Columns/IColumn.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Common/RadixSort.h>
#include <algorithm>
#include <type_traits>
#include <utility>
#define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE 0xFFFFFF
namespace DB
{
struct Settings;
namespace ErrorCodes
{
extern const int TOO_LARGE_ARRAY_SIZE;
}
template <typename T>
struct GroupArraySortedData;
template <typename T>
struct GroupArraySortedData
{
/// For easy serialization.
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>);
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
using Array = PODArray<T, 32, Allocator>;
Array value;
};
template <typename T>
class GroupArraySortedNumericImpl final
: public IAggregateFunctionDataHelper<GroupArraySortedData<T>, GroupArraySortedNumericImpl<T>>
{
using Data = GroupArraySortedData<T>;
UInt64 max_elems;
SerializationPtr serialization;
public:
explicit GroupArraySortedNumericImpl(
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<GroupArraySortedData<T>, GroupArraySortedNumericImpl<T>>(
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, max_elems(max_elems_)
, serialization(data_type_->getDefaultSerialization())
{
}
String getName() const override { return "groupArraySorted"; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
const auto & row_value = assert_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
auto & cur_elems = this->data(place);
cur_elems.value.push_back(row_value, arena);
/// To optimize, we sort (2 * max_size) elements of input array over and over again
/// and after each loop we delete the last half of sorted array
if (cur_elems.value.size() >= max_elems * 2)
{
RadixSort<RadixSortNumTraits<T>>::executeLSD(cur_elems.value.data(), cur_elems.value.size());
cur_elems.value.resize(max_elems, arena);
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
auto & cur_elems = this->data(place);
auto & rhs_elems = this->data(rhs);
if (rhs_elems.value.empty())
return;
if (rhs_elems.value.size())
cur_elems.value.insertByOffsets(rhs_elems.value, 0, rhs_elems.value.size(), arena);
RadixSort<RadixSortNumTraits<T>>::executeLSD(cur_elems.value.data(), cur_elems.value.size());
size_t elems_size = cur_elems.value.size() < max_elems ? cur_elems.value.size() : max_elems;
cur_elems.value.resize(elems_size, arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
auto & value = this->data(place).value;
size_t size = value.size();
writeVarUInt(size, buf);
for (const auto & elem : value)
writeBinaryLittleEndian(elem, buf);
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
size_t size = 0;
readVarUInt(size, buf);
if (unlikely(size > max_elems))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size, it should not exceed {}", max_elems);
auto & value = this->data(place).value;
value.resize(size, arena);
for (auto & element : value)
readBinaryLittleEndian(element, buf);
}
static void checkArraySize(size_t elems, size_t max_elems)
{
if (unlikely(elems > max_elems))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
"Too large array size {} (maximum: {})", elems, max_elems);
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
auto& value = this->data(place).value;
RadixSort<RadixSortNumTraits<T>>::executeLSD(value.data(), value.size());
size_t elems_size = value.size() < max_elems ? value.size() : max_elems;
value.resize(elems_size, arena);
size_t size = value.size();
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
offsets_to.push_back(offsets_to.back() + size);
if (size)
{
typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
data_to.insert(this->data(place).value.begin(), this->data(place).value.end());
RadixSort<RadixSortNumTraits<T>>::executeLSD(value.data(), value.size());
value.resize(elems_size, arena);
}
}
bool allocatesMemoryInArena() const override { return true; }
};
template <typename Node, bool has_sampler>
struct GroupArraySortedGeneralData;
template <typename Node>
struct GroupArraySortedGeneralData<Node, false>
{
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
using Allocator = MixedAlignedArenaAllocator<alignof(Node *), 4096>;
using Array = PODArray<Field, 32, Allocator>;
Array value;
};
template <typename Node>
struct GroupArraySortedNodeBase
{
UInt64 size; // size of payload
/// Returns pointer to actual payload
char * data() { return reinterpret_cast<char *>(this) + sizeof(Node); }
const char * data() const { return reinterpret_cast<const char *>(this) + sizeof(Node); }
};
struct GroupArraySortedNodeString : public GroupArraySortedNodeBase<GroupArraySortedNodeString>
{
using Node = GroupArraySortedNodeString;
};
struct GroupArraySortedNodeGeneral : public GroupArraySortedNodeBase<GroupArraySortedNodeGeneral>
{
using Node = GroupArraySortedNodeGeneral;
};
/// Implementation of groupArraySorted for Generic data via Array
template <typename Node>
class GroupArraySortedGeneralImpl final
: public IAggregateFunctionDataHelper<GroupArraySortedGeneralData<Node, false>, GroupArraySortedGeneralImpl<Node>>
{
using Data = GroupArraySortedGeneralData<Node, false>;
static Data & data(AggregateDataPtr __restrict place) { return *reinterpret_cast<Data *>(place); }
static const Data & data(ConstAggregateDataPtr __restrict place) { return *reinterpret_cast<const Data *>(place); }
DataTypePtr & data_type;
UInt64 max_elems;
SerializationPtr serialization;
public:
GroupArraySortedGeneralImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<GroupArraySortedGeneralData<Node, false>, GroupArraySortedGeneralImpl<Node>>(
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, data_type(this->argument_types[0])
, max_elems(max_elems_)
, serialization(data_type->getDefaultSerialization())
{
}
String getName() const override { return "groupArraySorted"; }
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
auto & cur_elems = data(place);
cur_elems.value.push_back(columns[0][0][row_num], arena);
/// To optimize, we sort (2 * max_size) elements of input array over and over again and
/// after each loop we delete the last half of sorted array
if (cur_elems.value.size() >= max_elems * 2)
{
std::sort(cur_elems.value.begin(), cur_elems.value.begin() + (max_elems * 2));
cur_elems.value.erase(cur_elems.value.begin() + max_elems, cur_elems.value.begin() + (max_elems * 2));
}
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
auto & cur_elems = data(place);
auto & rhs_elems = data(rhs);
if (rhs_elems.value.empty())
return;
UInt64 new_elems = rhs_elems.value.size();
for (UInt64 i = 0; i < new_elems; ++i)
cur_elems.value.push_back(rhs_elems.value[i], arena);
checkArraySize(cur_elems.value.size(), AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE);
if (!cur_elems.value.empty())
{
std::sort(cur_elems.value.begin(), cur_elems.value.end());
if (cur_elems.value.size() > max_elems)
cur_elems.value.resize(max_elems, arena);
}
}
static void checkArraySize(size_t elems, size_t max_elems)
{
if (unlikely(elems > max_elems))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE,
"Too large array size {} (maximum: {})", elems, max_elems);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
{
auto & value = data(place).value;
size_t size = value.size();
checkArraySize(size, AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE);
writeVarUInt(size, buf);
for (const Field & elem : value)
{
if (elem.isNull())
{
writeBinary(false, buf);
}
else
{
writeBinary(true, buf);
serialization->serializeBinary(elem, buf, {});
}
}
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
{
size_t size = 0;
readVarUInt(size, buf);
if (unlikely(size > max_elems))
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size, it should not exceed {}", max_elems);
checkArraySize(size, AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ELEMENT_SIZE);
auto & value = data(place).value;
value.resize(size, arena);
for (Field & elem : value)
{
UInt8 is_null = 0;
readBinary(is_null, buf);
if (!is_null)
serialization->deserializeBinary(elem, buf, {});
}
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
auto & column_array = assert_cast<ColumnArray &>(to);
auto & value = data(place).value;
if (!value.empty())
{
std::sort(value.begin(), value.end());
if (value.size() > max_elems)
value.resize_exact(max_elems, arena);
}
auto & offsets = column_array.getOffsets();
offsets.push_back(offsets.back() + value.size());
auto & column_data = column_array.getData();
if (std::is_same_v<Node, GroupArraySortedNodeString>)
{
auto & string_offsets = assert_cast<ColumnString &>(column_data).getOffsets();
string_offsets.reserve(string_offsets.size() + value.size());
}
for (const Field& field : value)
column_data.insert(field);
}
bool allocatesMemoryInArena() const override { return true; }
};
#undef AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE
}

View File

@ -15,6 +15,7 @@ void registerAggregateFunctionCount(AggregateFunctionFactory &);
void registerAggregateFunctionDeltaSum(AggregateFunctionFactory &);
void registerAggregateFunctionDeltaSumTimestamp(AggregateFunctionFactory &);
void registerAggregateFunctionGroupArray(AggregateFunctionFactory &);
void registerAggregateFunctionGroupArraySorted(AggregateFunctionFactory & factory);
void registerAggregateFunctionGroupUniqArray(AggregateFunctionFactory &);
void registerAggregateFunctionGroupArrayInsertAt(AggregateFunctionFactory &);
void registerAggregateFunctionsQuantile(AggregateFunctionFactory &);
@ -112,6 +113,7 @@ void registerAggregateFunctions()
registerAggregateFunctionDeltaSum(factory);
registerAggregateFunctionDeltaSumTimestamp(factory);
registerAggregateFunctionGroupArray(factory);
registerAggregateFunctionGroupArraySorted(factory);
registerAggregateFunctionGroupUniqArray(factory);
registerAggregateFunctionGroupArrayInsertAt(factory);
registerAggregateFunctionsQuantile(factory);

View File

@ -0,0 +1,31 @@
<test>
<settings>
<max_memory_usage>30000000000</max_memory_usage>
</settings>
<substitutions>
<substitution>
<name>millions</name>
<values>
<value>50</value>
<value>100</value>
</values>
</substitution>
<substitution>
<name>window</name>
<values>
<value>10</value>
<value>1000</value>
<value>10000</value>
</values>
</substitution>
</substitutions>
<create_query>create table sorted_{millions}m engine MergeTree order by k as select number % 100 k, rand() v from numbers_mt(1000000 * {millions})</create_query>
<create_query>optimize table sorted_{millions}m final</create_query>
<query>select k, groupArraySorted({window})(v) from sorted_{millions}m group by k format Null</query>
<query>select k % 10 kk, groupArraySorted({window})(v) from sorted_{millions}m group by kk format Null</query>
<drop_query>drop table if exists sorted_{millions}m</drop_query>
</test>

View File

@ -0,0 +1,12 @@
[0,1,2,3,4]
[0,1,2,3,4]
[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99]
['0','1','10','11','12','13','14','15','16','17','18','19','2','20','21','22','23','24','25','26','27','28','29','3','4','5','6','7','8','9']
[0,0,1,1,2,2,3,3,4,4]
[[1,2,3,4],[2,3,4,5],[3,4,5,6]]
[(2,1),(15,25),(30,60),(100,200)]
[0.2,2.2,6.6,12.5]
['AAA','Aaa','aaa','abc','bbc']
1000000
1000000
[0,1]

View File

@ -0,0 +1,41 @@
SELECT groupArraySorted(5)(number) FROM numbers(100);
SELECT groupArraySorted(10)(number) FROM numbers(5);
SELECT groupArraySorted(100)(number) FROM numbers(1000);
SELECT groupArraySorted(30)(str) FROM (SELECT toString(number) as str FROM numbers(30));
SELECT groupArraySorted(10)(toInt64(number/2)) FROM numbers(100);
DROP TABLE IF EXISTS test;
CREATE TABLE test (a Array(UInt64)) engine=MergeTree ORDER BY a;
INSERT INTO test VALUES ([3,4,5,6]), ([1,2,3,4]), ([2,3,4,5]);
SELECT groupArraySorted(3)(a) FROM test;
DROP TABLE test;
CREATE TABLE IF NOT EXISTS test (id Int32, data Tuple(Int32, Int32)) ENGINE = MergeTree() ORDER BY id;
INSERT INTO test (id, data) VALUES (1, (100, 200)), (2, (15, 25)), (3, (2, 1)), (4, (30, 60));
SELECT groupArraySorted(4)(data) FROM test;
DROP TABLE test;
CREATE TABLE IF NOT EXISTS test (id Int32, data Decimal32(2)) ENGINE = MergeTree() ORDER BY id;
INSERT INTO test (id, data) VALUES (1, 12.5), (2, 0.2), (3, 6.6), (4, 2.2);
SELECT groupArraySorted(4)(data) FROM test;
DROP TABLE test;
CREATE TABLE IF NOT EXISTS test (id Int32, data FixedString(3)) ENGINE = MergeTree() ORDER BY id;
INSERT INTO test (id, data) VALUES (1, 'AAA'), (2, 'bbc'), (3, 'abc'), (4, 'aaa'), (5, 'Aaa');
SELECT groupArraySorted(5)(data) FROM test;
DROP TABLE test;
CREATE TABLE test (id Decimal(76, 53), str String) ENGINE = MergeTree ORDER BY id;
INSERT INTO test SELECT number, 'test' FROM numbers(1000000);
SELECT count(id) FROM test;
SELECT count(concat(toString(id), 'a')) FROM test;
DROP TABLE test;
CREATE TABLE test (id UInt64, agg AggregateFunction(groupArraySorted(2), UInt64)) engine=MergeTree ORDER BY id;
INSERT INTO test SELECT 1, groupArraySortedState(2)(number) FROM numbers(10);
SELECT groupArraySortedMerge(2)(agg) FROM test;
DROP TABLE test;

View File

@ -1593,6 +1593,7 @@ groupArrayLast
groupArrayMovingAvg
groupArrayMovingSum
groupArraySample
groupArraySorted
groupBitAnd
groupBitOr
groupBitXor
@ -1607,6 +1608,7 @@ grouparraylast
grouparraymovingavg
grouparraymovingsum
grouparraysample
grouparraysorted
groupbitand
groupbitmap
groupbitmapand