mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 00:22:29 +00:00
rewrite function to be independent of groupArray, add test
This commit is contained in:
parent
69205769d0
commit
cc8ac432dd
@ -19,24 +19,24 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename ... TArgs>
|
||||
template <template <typename> class AggregateFunctionTemplate, typename ... TArgs>
|
||||
AggregateFunctionPtr createWithNumericOrTimeType(const IDataType & argument_type, TArgs && ... args)
|
||||
{
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::Date) return std::make_shared<AggregateFunctionTemplate<UInt16, Data>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::DateTime) return std::make_shared<AggregateFunctionTemplate<UInt32, Data>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::IPv4) return std::make_shared<AggregateFunctionTemplate<IPv4, Data>>(std::forward<TArgs>(args)...);
|
||||
return AggregateFunctionPtr(createWithNumericType<AggregateFunctionTemplate, Data, TArgs...>(argument_type, std::forward<TArgs>(args)...));
|
||||
if (which.idx == TypeIndex::Date) return std::make_shared<AggregateFunctionTemplate<UInt16>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::DateTime) return std::make_shared<AggregateFunctionTemplate<UInt32>>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::IPv4) return std::make_shared<AggregateFunctionTemplate<IPv4>>(std::forward<TArgs>(args)...);
|
||||
return AggregateFunctionPtr(createWithNumericType<AggregateFunctionTemplate, TArgs...>(argument_type, std::forward<TArgs>(args)...));
|
||||
}
|
||||
|
||||
template <typename Trait, typename ... TArgs>
|
||||
template <typename ... TArgs>
|
||||
inline AggregateFunctionPtr createAggregateFunctionGroupArraySortedImpl(const DataTypePtr & argument_type, const Array & parameters, TArgs ... args)
|
||||
{
|
||||
if (auto res = createWithNumericOrTimeType<GroupArraySortedNumericImpl, Trait>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))
|
||||
if (auto res = createWithNumericOrTimeType<GroupArraySortedNumericImpl>(*argument_type, argument_type, parameters, std::forward<TArgs>(args)...))
|
||||
return AggregateFunctionPtr(res);
|
||||
|
||||
WhichDataType which(argument_type);
|
||||
return std::make_shared<GroupArraySortedGeneralImpl<GroupArrayNodeGeneral>>(argument_type, parameters, std::forward<TArgs>(args)...);
|
||||
return std::make_shared<GroupArraySortedGeneralImpl<GroupArraySortedNodeGeneral>>(argument_type, parameters, std::forward<TArgs>(args)...);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionGroupArraySorted(
|
||||
@ -66,7 +66,7 @@ AggregateFunctionPtr createAggregateFunctionGroupArraySorted(
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Function {} does not support this number of arguments", name);
|
||||
|
||||
return createAggregateFunctionGroupArraySortedImpl<GroupArrayTrait</* Thas_limit= */ true, false, /* Tsampler= */ Sampler::NONE>>(argument_types[0], parameters, max_elems);
|
||||
return createAggregateFunctionGroupArraySortedImpl(argument_types[0], parameters, max_elems);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -13,7 +13,6 @@
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <AggregateFunctions/AggregateFunctionGroupArray.h>
|
||||
#include <Functions/array/arraySort.h>
|
||||
|
||||
#include <Common/Exception.h>
|
||||
@ -43,10 +42,10 @@ namespace ErrorCodes
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct GroupArraySortedNumericData;
|
||||
struct GroupArraySortedData;
|
||||
|
||||
template <typename T>
|
||||
struct GroupArraySortedNumericData
|
||||
struct GroupArraySortedData
|
||||
{
|
||||
/// For easy serialization.
|
||||
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>);
|
||||
@ -58,19 +57,18 @@ struct GroupArraySortedNumericData
|
||||
Array value;
|
||||
};
|
||||
|
||||
template <typename T, typename Trait>
|
||||
template <typename T>
|
||||
class GroupArraySortedNumericImpl final
|
||||
: public IAggregateFunctionDataHelper<GroupArraySortedNumericData<T>, GroupArraySortedNumericImpl<T, Trait>>
|
||||
: public IAggregateFunctionDataHelper<GroupArraySortedData<T>, GroupArraySortedNumericImpl<T>>
|
||||
{
|
||||
using Data = GroupArraySortedNumericData<T>;
|
||||
static constexpr bool limit_num_elems = Trait::has_limit;
|
||||
using Data = GroupArraySortedData<T>;
|
||||
UInt64 max_elems;
|
||||
SerializationPtr serialization;
|
||||
|
||||
public:
|
||||
explicit GroupArraySortedNumericImpl(
|
||||
const DataTypePtr & data_type_, const Array & parameters_, UInt64 max_elems_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<GroupArraySortedNumericData<T>, GroupArraySortedNumericImpl<T, Trait>>(
|
||||
: IAggregateFunctionDataHelper<GroupArraySortedData<T>, GroupArraySortedNumericImpl<T>>(
|
||||
{data_type_}, parameters_, std::make_shared<DataTypeArray>(data_type_))
|
||||
, max_elems(max_elems_)
|
||||
, serialization(data_type_->getDefaultSerialization())
|
||||
@ -105,12 +103,11 @@ public:
|
||||
|
||||
if (rhs_elems.value.size())
|
||||
cur_elems.value.insertByOffsets(rhs_elems.value, 0, rhs_elems.value.size(), arena);
|
||||
if (cur_elems.value.size() < max_elems)
|
||||
throw Exception(ErrorCodes::INCORRECT_DATA, "The max size of result array is bigger than the actual array size");
|
||||
|
||||
RadixSort<RadixSortNumTraits<T>>::executeLSD(cur_elems.value.data(), cur_elems.value.size());
|
||||
if (limit_num_elems)
|
||||
cur_elems.value.resize(max_elems, arena);
|
||||
|
||||
size_t elems_size = cur_elems.value.size() < max_elems ? cur_elems.value.size() : max_elems;
|
||||
cur_elems.value.resize(elems_size, arena);
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
|
||||
@ -148,11 +145,10 @@ public:
|
||||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
|
||||
{
|
||||
auto& value = this->data(place).value;
|
||||
if (value.size() < max_elems)
|
||||
throw Exception(ErrorCodes::INCORRECT_DATA, "The max size of result array is bigger than the actual array size");
|
||||
|
||||
RadixSort<RadixSortNumTraits<T>>::executeLSD(value.data(), value.size());
|
||||
if (limit_num_elems)
|
||||
value.resize(max_elems, arena);
|
||||
size_t elems_size = value.size() < max_elems ? value.size() : max_elems;
|
||||
value.resize(elems_size, arena);
|
||||
size_t size = value.size();
|
||||
|
||||
ColumnArray & arr_to = assert_cast<ColumnArray &>(to);
|
||||
@ -164,6 +160,8 @@ public:
|
||||
{
|
||||
typename ColumnVector<T>::Container & data_to = assert_cast<ColumnVector<T> &>(arr_to.getData()).getData();
|
||||
data_to.insert(this->data(place).value.begin(), this->data(place).value.end());
|
||||
RadixSort<RadixSortNumTraits<T>>::executeLSD(value.data(), value.size());
|
||||
value.resize(elems_size, arena);
|
||||
}
|
||||
}
|
||||
|
||||
@ -185,8 +183,27 @@ struct GroupArraySortedGeneralData<Node, false>
|
||||
};
|
||||
|
||||
template <typename Node>
|
||||
struct GroupArraySortedGeneralData<Node, true> : public GroupArraySamplerData<Node *>
|
||||
struct GroupArraySortedNodeBase
|
||||
{
|
||||
UInt64 size; // size of payload
|
||||
|
||||
/// Returns pointer to actual payload
|
||||
char * data() { return reinterpret_cast<char *>(this) + sizeof(Node); }
|
||||
|
||||
const char * data() const { return reinterpret_cast<const char *>(this) + sizeof(Node); }
|
||||
};
|
||||
|
||||
struct GroupArraySortedNodeString : public GroupArraySortedNodeBase<GroupArraySortedNodeString>
|
||||
{
|
||||
using Node = GroupArraySortedNodeString;
|
||||
|
||||
|
||||
};
|
||||
|
||||
struct GroupArraySortedNodeGeneral : public GroupArraySortedNodeBase<GroupArraySortedNodeGeneral>
|
||||
{
|
||||
using Node = GroupArraySortedNodeGeneral;
|
||||
|
||||
};
|
||||
|
||||
/// Implementation of groupArraySorted for Generic data via Array
|
||||
@ -303,7 +320,7 @@ public:
|
||||
|
||||
auto & column_data = column_array.getData();
|
||||
|
||||
if (std::is_same_v<Node, GroupArrayNodeString>)
|
||||
if (std::is_same_v<Node, GroupArraySortedNodeString>)
|
||||
{
|
||||
auto & string_offsets = assert_cast<ColumnString &>(column_data).getOffsets();
|
||||
string_offsets.reserve(string_offsets.size() + value.size());
|
||||
|
@ -1,4 +1,5 @@
|
||||
[0,1,2,3,4]
|
||||
[0,1,2,3,4]
|
||||
[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99]
|
||||
['0','1','10','11','12','13','14','15','16','17','18','19','2','20','21','22','23','24','25','26','27','28','29','3','4','5','6','7','8','9']
|
||||
[0,0,1,1,2,2,3,3,4,4]
|
||||
|
@ -1,6 +1,8 @@
|
||||
SELECT groupArraySorted(5)(number) from numbers(100);
|
||||
SELECT groupArraySorted(5)(number) FROM numbers(100);
|
||||
|
||||
SELECT groupArraySorted(100)(number) from numbers(1000);
|
||||
SELECT groupArraySorted(10)(number) FROM numbers(5);
|
||||
|
||||
SELECT groupArraySorted(100)(number) FROM numbers(1000);
|
||||
|
||||
SELECT groupArraySorted(30)(str) FROM (SELECT toString(number) as str FROM numbers(30));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user