rewrite function to be independent of groupArray, add test

This commit is contained in:
yariks5s 2023-11-24 18:55:15 +00:00
parent 69205769d0
commit cc8ac432dd
4 changed files with 51 additions and 31 deletions

View File

@ -19,24 +19,24 @@ namespace ErrorCodes
namespace
{
template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename ... TArgs>
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
AggregateFunctionPtr createWithNumericOrTimeType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Date) return std::make_shared<AggregateFunctionTemplate<UInt16, Data>>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::DateTime) return std::make_shared<AggregateFunctionTemplate<UInt32, Data>>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::IPv4) return std::make_shared<AggregateFunctionTemplate<IPv4, Data>>(std::forward<TArgs>(args)...);
return AggregateFunctionPtr(createWithNumericType<AggregateFunctionTemplate, Data, TArgs...>(argument_type, std::forward<TArgs>(args)...));
if (which.idx == TypeIndex::Date) return std::make_shared<AggregateFunctionTemplate<UInt16>>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::DateTime) return std::make_shared<AggregateFunctionTemplate<UInt32>>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::IPv4) return std::make_shared<AggregateFunctionTemplate<IPv4>>(std::forward<TArgs>(args)...);
return AggregateFunctionPtr(createWithNumericType<AggregateFunctionTemplate, TArgs...>(argument_type, std::forward<TArgs>(args)...));
}
template <typename Trait, typename ... TArgs>
template <typename ... TArgs>
inline AggregateFunctionPtr createAggregateFunctionGroupArraySortedImpl(const DataTypePtr & argument_type, const Array & parameters, TArgs ... args)
{
if (auto res = createWithNumericOrTimeType<GroupArraySortedNumericImpl, Trait>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))
if (auto res = createWithNumericOrTimeType<GroupArraySortedNumericImpl>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))
return AggregateFunctionPtr(res);
WhichDataType which(argument_type);
return std::make_shared<GroupArraySortedGeneralImpl<GroupArrayNodeGeneral>>(argument_type, parameters, std::forward<TArgs>(args)...);
return std::make_shared<GroupArraySortedGeneralImpl<GroupArraySortedNodeGeneral>>(argument_type, parameters, std::forward<TArgs>(args)...);
}
AggregateFunctionPtr createAggregateFunctionGroupArraySorted(
@ -66,7 +66,7 @@ AggregateFunctionPtr createAggregateFunctionGroupArraySorted(
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} does not support this number of arguments", name);
return createAggregateFunctionGroupArraySortedImpl<GroupArrayTrait</* Thas_limit= */ true, false, /* Tsampler= */ Sampler::NONE>>(argument_types[0], parameters, max_elems);
return createAggregateFunctionGroupArraySortedImpl(argument_types[0], parameters, max_elems);
}
}

View File

@ -13,7 +13,6 @@
#include <Columns/ColumnArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <AggregateFunctions/AggregateFunctionGroupArray.h>
#include <Functions/array/arraySort.h>
#include <Common/Exception.h>
@ -43,10 +42,10 @@ namespace ErrorCodes
}
template <typename T>
struct GroupArraySortedNumericData;
struct GroupArraySortedData;
template <typename T>
struct GroupArraySortedNumericData
struct GroupArraySortedData
{
/// For easy serialization.
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>);
@ -58,19 +57,18 @@ struct GroupArraySortedNumericData
Array value;
};
template <typename T, typename Trait>
template <typename T>
class GroupArraySortedNumericImpl final
: public IAggregateFunctionDataHelper<GroupArraySortedNumericData<T>, GroupArraySortedNumericImpl<T, Trait>>
: public IAggregateFunctionDataHelper<GroupArraySortedData<T>, GroupArraySortedNumericImpl<T>>
{
using Data = GroupArraySortedNumericData<T>;
static constexpr bool limit_num_elems = Trait::has_limit;
using Data = GroupArraySortedData<T>;
UInt64 max_elems;
SerializationPtr serialization;
public:
explicit GroupArraySortedNumericImpl(
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<GroupArraySortedNumericData<T>, GroupArraySortedNumericImpl<T, Trait>>(
: IAggregateFunctionDataHelper<GroupArraySortedData<T>, GroupArraySortedNumericImpl<T>>(
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
, max_elems(max_elems_)
, serialization(data_type_->getDefaultSerialization())
@ -105,12 +103,11 @@ public:
if (rhs_elems.value.size())
cur_elems.value.insertByOffsets(rhs_elems.value, 0, rhs_elems.value.size(), arena);
if (cur_elems.value.size() < max_elems)
throw Exception(ErrorCodes::INCORRECT_DATA, "The max size of result array is bigger than the actual array size");
RadixSort<RadixSortNumTraits<T>>::executeLSD(cur_elems.value.data(), cur_elems.value.size());
if (limit_num_elems)
cur_elems.value.resize(max_elems, arena);
size_t elems_size = cur_elems.value.size() < max_elems ? cur_elems.value.size() : max_elems;
cur_elems.value.resize(elems_size, arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
@ -148,11 +145,10 @@ public:
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
{
auto& value = this->data(place).value;
if (value.size() < max_elems)
throw Exception(ErrorCodes::INCORRECT_DATA, "The max size of result array is bigger than the actual array size");
RadixSort<RadixSortNumTraits<T>>::executeLSD(value.data(), value.size());
if (limit_num_elems)
value.resize(max_elems, arena);
size_t elems_size = value.size() < max_elems ? value.size() : max_elems;
value.resize(elems_size, arena);
size_t size = value.size();
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
@ -164,6 +160,8 @@ public:
{
typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
data_to.insert(this->data(place).value.begin(), this->data(place).value.end());
RadixSort<RadixSortNumTraits<T>>::executeLSD(value.data(), value.size());
value.resize(elems_size, arena);
}
}
@ -185,8 +183,27 @@ struct GroupArraySortedGeneralData<Node, false>
};
template <typename Node>
struct GroupArraySortedGeneralData<Node, true> : public GroupArraySamplerData<Node *>
struct GroupArraySortedNodeBase
{
UInt64 size; // size of payload
/// Returns pointer to actual payload
char * data() { return reinterpret_cast<char *>(this) + sizeof(Node); }
const char * data() const { return reinterpret_cast<const char *>(this) + sizeof(Node); }
};
struct GroupArraySortedNodeString : public GroupArraySortedNodeBase<GroupArraySortedNodeString>
{
using Node = GroupArraySortedNodeString;
};
struct GroupArraySortedNodeGeneral : public GroupArraySortedNodeBase<GroupArraySortedNodeGeneral>
{
using Node = GroupArraySortedNodeGeneral;
};
/// Implementation of groupArraySorted for Generic data via Array
@ -303,7 +320,7 @@ public:
auto & column_data = column_array.getData();
if (std::is_same_v<Node, GroupArrayNodeString>)
if (std::is_same_v<Node, GroupArraySortedNodeString>)
{
auto & string_offsets = assert_cast<ColumnString &>(column_data).getOffsets();
string_offsets.reserve(string_offsets.size() + value.size());

View File

@ -1,4 +1,5 @@
[0,1,2,3,4]
[0,1,2,3,4]
[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99]
['0','1','10','11','12','13','14','15','16','17','18','19','2','20','21','22','23','24','25','26','27','28','29','3','4','5','6','7','8','9']
[0,0,1,1,2,2,3,3,4,4]

View File

@ -1,6 +1,8 @@
SELECT groupArraySorted(5)(number) from numbers(100);
SELECT groupArraySorted(5)(number) FROM numbers(100);
SELECT groupArraySorted(100)(number) from numbers(1000);
SELECT groupArraySorted(10)(number) FROM numbers(5);
SELECT groupArraySorted(100)(number) FROM numbers(1000);
SELECT groupArraySorted(30)(str) FROM (SELECT toString(number) as str FROM numbers(30));