diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.cpp b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.cpp index d11be31a4da..d4801b723ab 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.cpp @@ -29,17 +29,23 @@ static IAggregateFunction * createWithNumericOrTimeType(const IDataType & argume } -template +template inline AggregateFunctionPtr createAggregateFunctionGroupArrayImpl(const DataTypePtr & argument_type, TArgs ... args) { - if (auto res = createWithNumericOrTimeType(*argument_type, argument_type, std::forward(args)...)) + if (auto res = createWithNumericOrTimeType(*argument_type, argument_type, std::forward(args)...)) return AggregateFunctionPtr(res); WhichDataType which(argument_type); if (which.idx == TypeIndex::String) - return std::make_shared>(argument_type, std::forward(args)...); + return std::make_shared>(argument_type, std::forward(args)...); - return std::make_shared>(argument_type, std::forward(args)...); + return std::make_shared>(argument_type, std::forward(args)...); + + // Link list implementation doesn't show noticeable performance improvement + // if (which.idx == TypeIndex::String) + // return std::make_shared>(argument_type, std::forward(args)...); + + // return std::make_shared>(argument_type, std::forward(args)...); } @@ -72,9 +78,38 @@ static AggregateFunctionPtr createAggregateFunctionGroupArray(const std::string ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (!limit_size) - return createAggregateFunctionGroupArrayImpl(argument_types[0]); + return createAggregateFunctionGroupArrayImpl>(argument_types[0]); else - return createAggregateFunctionGroupArrayImpl(argument_types[0], max_elems); + return createAggregateFunctionGroupArrayImpl>(argument_types[0], max_elems); +} + +static AggregateFunctionPtr +createAggregateFunctionGroupArraySample(const std::string & name, const DataTypes & argument_types, const Array & parameters) +{ + assertUnary(name, argument_types); + + UInt64 max_elems = std::numeric_limits::max(); + UInt64 seed = 123456; + + UInt64 * params[2] = {&max_elems, &seed}; + if (parameters.size() != 1 && parameters.size() != 2) + throw Exception("Incorrect number of parameters for aggregate function " + name + ", should be 1 or 2", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + for (auto i = 0ul; i < parameters.size(); ++i) + { + auto type = parameters[i].getType(); + if (type != Field::Types::Int64 && type != Field::Types::UInt64) + throw Exception("Parameter for aggregate function " + name + " should be positive number", ErrorCodes::BAD_ARGUMENTS); + + if ((type == Field::Types::Int64 && parameters[i].get() < 0) || + (type == Field::Types::UInt64 && parameters[i].get() == 0)) + throw Exception("Parameter for aggregate function " + name + " should be positive number", ErrorCodes::BAD_ARGUMENTS); + + *params[i] = parameters[i].get(); + } + + return createAggregateFunctionGroupArrayImpl>(argument_types[0], max_elems, seed); } } @@ -83,6 +118,7 @@ static AggregateFunctionPtr createAggregateFunctionGroupArray(const std::string void registerAggregateFunctionGroupArray(AggregateFunctionFactory & factory) { factory.registerFunction("groupArray", createAggregateFunctionGroupArray); + factory.registerFunction("groupArraySample", createAggregateFunctionGroupArraySample); } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h index d58739e1dd8..432f315103a 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionGroupArray.h @@ -31,10 +31,38 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } +enum class Sampler +{ + NONE, + RNG, + DETERMINATOR // TODO +}; + +template +struct GroupArrayTrait +{ + static constexpr bool has_limit = Thas_limit; + static constexpr Sampler sampler = Tsampler; +}; + +template +static constexpr const char * getNameByTrait() +{ + if (Trait::sampler == Sampler::NONE) + return "groupArray"; + else if (Trait::sampler == Sampler::RNG) + return "groupArraySample"; + // else if (Trait::sampler == Sampler::DETERMINATOR) // TODO + + __builtin_unreachable(); +} /// A particular case is an implementation for numeric types. +template +struct GroupArrayNumericData; + template -struct GroupArrayNumericData +struct GroupArrayNumericData { // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena using Allocator = MixedAlignedArenaAllocator; @@ -43,51 +71,162 @@ struct GroupArrayNumericData Array value; }; - -template -class GroupArrayNumericImpl final - : public IAggregateFunctionDataHelper, GroupArrayNumericImpl> +template +struct GroupArrayNumericData { - static constexpr bool limit_num_elems = Tlimit_num_elems::value; + // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena + using Allocator = MixedAlignedArenaAllocator; + using Array = PODArray; + + Array value; + size_t total_values = 0; + pcg32_fast rng; + + UInt64 genRandom(size_t lim) + { + /// With a large number of values, we will generate random numbers several times slower. + if (lim <= static_cast(rng.max())) + return static_cast(rng()) % static_cast(lim); + else + return (static_cast(rng()) * (static_cast(rng.max()) + 1ULL) + static_cast(rng())) % lim; + } + + void randomShuffle() + { + for (size_t i = 1; i < value.size(); ++i) + { + size_t j = genRandom(i + 1); + std::swap(value[i], value[j]); + } + } +}; + +template +class GroupArrayNumericImpl final + : public IAggregateFunctionDataHelper, GroupArrayNumericImpl> +{ + using Data = GroupArrayNumericData; + static constexpr bool limit_num_elems = Trait::has_limit; DataTypePtr & data_type; UInt64 max_elems; + UInt64 seed; public: - explicit GroupArrayNumericImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits::max()) - : IAggregateFunctionDataHelper, GroupArrayNumericImpl>({data_type_}, {}) - , data_type(this->argument_types[0]), max_elems(max_elems_) {} + explicit GroupArrayNumericImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits::max(), UInt64 seed_ = 123456) + : IAggregateFunctionDataHelper, GroupArrayNumericImpl>( + {data_type_}, {}) + , data_type(this->argument_types[0]) + , max_elems(max_elems_) + , seed(seed_) + { + } - String getName() const override { return "groupArray"; } + String getName() const override { return getNameByTrait(); } DataTypePtr getReturnType() const override { return std::make_shared(data_type); } + void insert(Data & a, const T & v, Arena * arena) const + { + ++a.total_values; + if (a.value.size() < max_elems) + a.value.push_back(v, arena); + else + { + UInt64 rnd = a.genRandom(a.total_values); + if (rnd < max_elems) + a.value[rnd] = v; + } + } + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override { - if (limit_num_elems && this->data(place).value.size() >= max_elems) - return; + if constexpr (Trait::sampler == Sampler::NONE) + { + if (limit_num_elems && this->data(place).value.size() >= max_elems) + return; - this->data(place).value.push_back(assert_cast &>(*columns[0]).getData()[row_num], arena); + this->data(place).value.push_back(assert_cast &>(*columns[0]).getData()[row_num], arena); + } + + if constexpr (Trait::sampler == Sampler::RNG) + { + auto & a = this->data(place); + ++a.total_values; + if (a.value.empty()) + a.rng.seed(seed); + if (a.value.size() < max_elems) + a.value.push_back(assert_cast &>(*columns[0]).getData()[row_num], arena); + else + { + UInt64 rnd = a.genRandom(a.total_values); + if (rnd < max_elems) + a.value[rnd] = assert_cast &>(*columns[0]).getData()[row_num]; + } + } + // TODO + // if constexpr (Trait::sampler == Sampler::DETERMINATOR) } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override { - auto & cur_elems = this->data(place); - auto & rhs_elems = this->data(rhs); + if constexpr (Trait::sampler == Sampler::NONE) + { + auto & cur_elems = this->data(place); + auto & rhs_elems = this->data(rhs); - if (!limit_num_elems) - { - if (rhs_elems.value.size()) - cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena); + if (!limit_num_elems) + { + if (rhs_elems.value.size()) + cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena); + } + else + { + UInt64 elems_to_insert = std::min(static_cast(max_elems) - cur_elems.value.size(), rhs_elems.value.size()); + if (elems_to_insert) + cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.begin() + elems_to_insert, arena); + } } - else + + if constexpr (Trait::sampler == Sampler::RNG) { - UInt64 elems_to_insert = std::min(static_cast(max_elems) - cur_elems.value.size(), rhs_elems.value.size()); - if (elems_to_insert) - cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.begin() + elems_to_insert, arena); + if (this->data(rhs).value.empty()) /// rhs state is empty + return; + + auto & a = this->data(place); + auto & b = this->data(rhs); + + if (b.total_values <= max_elems) + { + for (size_t i = 0; i < b.value.size(); ++i) + insert(a, b.value[i], arena); + } + else if (a.total_values <= max_elems) + { + decltype(a.value) from; + from.swap(a.value, arena); + a.value.assign(b.value.begin(), b.value.end(), arena); + a.total_values = b.total_values; + for (size_t i = 0; i < from.size(); ++i) + insert(a, from[i], arena); + } + else + { + a.randomShuffle(); + a.total_values += b.total_values; + for (size_t i = 0; i < max_elems; ++i) + { + UInt64 rnd = a.genRandom(a.total_values); + if (rnd < b.total_values) + a.value[i] = b.value[i]; + } + } } + + // TODO + // if constexpr (Trait::sampler == Sampler::DETERMINATOR) } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override @@ -96,6 +235,17 @@ public: size_t size = value.size(); writeVarUInt(size, buf); buf.write(reinterpret_cast(value.data()), size * sizeof(value[0])); + + if constexpr (Trait::sampler == Sampler::RNG) + { + DB::writeIntBinary(this->data(place).total_values, buf); + std::ostringstream rng_stream; + rng_stream << this->data(place).rng; + DB::writeStringBinary(rng_stream.str(), buf); + } + + // TODO + // if constexpr (Trait::sampler == Sampler::DETERMINATOR) } void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override @@ -113,6 +263,18 @@ public: value.resize(size, arena); buf.read(reinterpret_cast(value.data()), size * sizeof(value[0])); + + if constexpr (Trait::sampler == Sampler::RNG) + { + DB::readIntBinary(this->data(place).total_values, buf); + std::string rng_string; + DB::readStringBinary(rng_string, buf); + std::istringstream rng_stream(rng_string); + rng_stream >> this->data(place).rng; + } + + // TODO + // if constexpr (Trait::sampler == Sampler::DETERMINATOR) } void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override @@ -145,26 +307,30 @@ public: /// Nodes used to implement a linked list for storage of groupArray states template -struct GroupArrayListNodeBase +struct GroupArrayNodeBase { - Node * next; UInt64 size; // size of payload /// Returns pointer to actual payload char * data() { - static_assert(sizeof(GroupArrayListNodeBase) == sizeof(Node)); return reinterpret_cast(this) + sizeof(Node); } - /// Clones existing node (does not modify next field) - Node * clone(Arena * arena) + const char * data() const { - return reinterpret_cast(const_cast(arena->alignedInsert(reinterpret_cast(this), sizeof(Node) + size, alignof(Node)))); + return reinterpret_cast(this) + sizeof(Node); + } + + /// Clones existing node (does not modify next field) + Node * clone(Arena * arena) const + { + return reinterpret_cast( + const_cast(arena->alignedInsert(reinterpret_cast(this), sizeof(Node) + size, alignof(Node)))); } /// Write node to buffer - void write(WriteBuffer & buf) + void write(WriteBuffer & buf) const { writeVarUInt(size, buf); buf.write(data(), size); @@ -183,6 +349,343 @@ struct GroupArrayListNodeBase } }; +struct GroupArrayNodeString : public GroupArrayNodeBase +{ + using Node = GroupArrayNodeString; + + /// Create node from string + static Node * allocate(const IColumn & column, size_t row_num, Arena * arena) + { + StringRef string = assert_cast(column).getDataAt(row_num); + + Node * node = reinterpret_cast(arena->alignedAlloc(sizeof(Node) + string.size, alignof(Node))); + node->size = string.size; + memcpy(node->data(), string.data, string.size); + + return node; + } + + void insertInto(IColumn & column) + { + assert_cast(column).insertData(data(), size); + } +}; + +struct GroupArrayNodeGeneral : public GroupArrayNodeBase +{ + using Node = GroupArrayNodeGeneral; + + static Node * allocate(const IColumn & column, size_t row_num, Arena * arena) + { + const char * begin = arena->alignedAlloc(sizeof(Node), alignof(Node)); + StringRef value = column.serializeValueIntoArena(row_num, *arena, begin); + + Node * node = reinterpret_cast(const_cast(begin)); + node->size = value.size; + + return node; + } + + void insertInto(IColumn & column) + { + column.deserializeAndInsertFromArena(data()); + } +}; + +class MyAllocator : protected Allocator +{ + using Base = Allocator; +public: + void * alloc(size_t size, Arena *) + { + return Base::alloc(size, 8); + } + + void free(void * buf, size_t size) + { + Base::free(buf, size); + } + + void * realloc(void * buf, size_t old_size, size_t new_size, Arena *) + { + return Base::realloc(buf, old_size, new_size, 8); + } +}; + + +template +struct GroupArrayGeneralData; + +template +struct GroupArrayGeneralData +{ + // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena + using Allocator = MixedAlignedArenaAllocator; + // using Allocator = MyAllocator; + using Array = PODArray; + + Array value; +}; + +template +struct GroupArrayGeneralData +{ + // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena + using Allocator = MixedAlignedArenaAllocator; + // using Allocator = MyAllocator; + using Array = PODArray; + + Array value; + size_t total_values = 0; + pcg32_fast rng; + + UInt64 genRandom(size_t lim) + { + /// With a large number of values, we will generate random numbers several times slower. + if (lim <= static_cast(rng.max())) + return static_cast(rng()) % static_cast(lim); + else + return (static_cast(rng()) * (static_cast(rng.max()) + 1ULL) + static_cast(rng())) % lim; + } + + void randomShuffle() + { + for (size_t i = 1; i < value.size(); ++i) + { + size_t j = genRandom(i + 1); + std::swap(value[i], value[j]); + } + } +}; + +/// Implementation of groupArray for String or any ComplexObject via Array +template +class GroupArrayGeneralImpl final : public IAggregateFunctionDataHelper< + GroupArrayGeneralData, + GroupArrayGeneralImpl> +{ + static constexpr bool limit_num_elems = Trait::has_limit; + using Data = GroupArrayGeneralData; + static Data & data(AggregateDataPtr place) { return *reinterpret_cast(place); } + static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast(place); } + + DataTypePtr & data_type; + UInt64 max_elems; + UInt64 seed; + +public: + GroupArrayGeneralImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits::max(), UInt64 seed_ = 123456) + : IAggregateFunctionDataHelper< + GroupArrayGeneralData, + GroupArrayGeneralImpl>({data_type_}, {}) + , data_type(this->argument_types[0]) + , max_elems(max_elems_) + , seed(seed_) + { + } + + String getName() const override { return getNameByTrait(); } + + DataTypePtr getReturnType() const override { return std::make_shared(data_type); } + + void insert(Data & a, const Node * v, Arena * arena) const + { + ++a.total_values; + if (a.value.size() < max_elems) + a.value.push_back(v->clone(arena), arena); + else + { + UInt64 rnd = a.genRandom(a.total_values); + if (rnd < max_elems) + a.value[rnd] = v->clone(arena); + } + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + if constexpr (Trait::sampler == Sampler::NONE) + { + if (limit_num_elems && data(place).value.size() >= max_elems) + return; + + Node * node = Node::allocate(*columns[0], row_num, arena); + data(place).value.push_back(node, arena); + } + + if constexpr (Trait::sampler == Sampler::RNG) + { + auto & a = data(place); + ++a.total_values; + if (a.value.empty()) + a.rng.seed(seed); + if (a.value.size() < max_elems) + a.value.push_back(Node::allocate(*columns[0], row_num, arena), arena); + else + { + UInt64 rnd = a.genRandom(a.total_values); + if (rnd < max_elems) + a.value[rnd] = Node::allocate(*columns[0], row_num, arena); + } + } + // TODO + // if constexpr (Trait::sampler == Sampler::DETERMINATOR) + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override + { + if constexpr (Trait::sampler == Sampler::NONE) + mergeNoSampler(place, rhs, arena); + else if constexpr (Trait::sampler == Sampler::RNG) + mergeWithRNGSampler(place, rhs, arena); + // TODO + // else if constexpr (Trait::sampler == Sampler::DETERMINATOR) + } + + void ALWAYS_INLINE mergeNoSampler(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const + { + if (data(rhs).value.empty()) /// rhs state is empty + return; + + UInt64 new_elems; + if (limit_num_elems) + { + if (data(place).value.size() >= max_elems) + return; + + new_elems = std::min(data(rhs).value.size(), max_elems - data(place).value.size()); + } + else + new_elems = data(rhs).value.size(); + + auto & a = data(place).value; + auto & b = data(rhs).value; + for (UInt64 i = 0; i < new_elems; ++i) + a.push_back(b[i]->clone(arena), arena); + } + + void ALWAYS_INLINE mergeWithRNGSampler(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const + { + if (data(rhs).value.empty()) /// rhs state is empty + return; + + auto & a = data(place); + auto & b = data(rhs); + + if (b.total_values <= max_elems) + { + for (size_t i = 0; i < b.value.size(); ++i) + insert(a, b.value[i], arena); + } + else if (a.total_values <= max_elems) + { + decltype(a.value) from; + from.swap(a.value, arena); + for (auto & node : b.value) + a.value.push_back(node->clone(arena), arena); + a.total_values = b.total_values; + for (size_t i = 0; i < from.size(); ++i) + insert(a, from[i], arena); + } + else + { + a.randomShuffle(); + a.total_values += b.total_values; + for (size_t i = 0; i < max_elems; ++i) + { + UInt64 rnd = a.genRandom(a.total_values); + if (rnd < b.total_values) + a.value[i] = b.value[i]->clone(arena); + } + } + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + writeVarUInt(data(place).value.size(), buf); + + auto & value = data(place).value; + for (auto & node : value) + node->write(buf); + + if constexpr (Trait::sampler == Sampler::RNG) + { + DB::writeIntBinary(data(place).total_values, buf); + std::ostringstream rng_stream; + rng_stream << data(place).rng; + DB::writeStringBinary(rng_stream.str(), buf); + } + + // TODO + // if constexpr (Trait::sampler == Sampler::DETERMINATOR) + } + + void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override + { + UInt64 elems; + readVarUInt(elems, buf); + + if (unlikely(elems == 0)) + return; + + if (unlikely(elems > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE)) + throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE); + + if (limit_num_elems && unlikely(elems > max_elems)) + throw Exception("Too large array size, it should not exceed " + toString(max_elems), ErrorCodes::TOO_LARGE_ARRAY_SIZE); + + auto & value = data(place).value; + + value.resize(elems, arena); + for (UInt64 i = 0; i < elems; ++i) + value[i] = Node::read(buf, arena); + + if constexpr (Trait::sampler == Sampler::RNG) + { + DB::readIntBinary(data(place).total_values, buf); + std::string rng_string; + DB::readStringBinary(rng_string, buf); + std::istringstream rng_stream(rng_string); + rng_stream >> data(place).rng; + } + + // TODO + // if constexpr (Trait::sampler == Sampler::DETERMINATOR) + } + + void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override + { + auto & column_array = assert_cast(to); + + auto & offsets = column_array.getOffsets(); + offsets.push_back(offsets.back() + data(place).value.size()); + + auto & column_data = column_array.getData(); + + if (std::is_same_v) + { + auto & string_offsets = assert_cast(column_data).getOffsets(); + string_offsets.reserve(string_offsets.size() + data(place).value.size()); + } + + auto & value = data(place).value; + for (auto & node : value) + node->insertInto(column_data); + } + + bool allocatesMemoryInArena() const override + { + return true; + } + + const char * getHeaderFilePath() const override { return __FILE__; } +}; + +template +struct GroupArrayListNodeBase : public GroupArrayNodeBase +{ + Node * next; +}; + struct GroupArrayListNodeString : public GroupArrayListNodeBase { using Node = GroupArrayListNodeString; @@ -240,10 +743,12 @@ struct GroupArrayGeneralListData /// Implementation of groupArray for String or any ComplexObject via linked list /// It has poor performance in case of many small objects -template -class GroupArrayGeneralListImpl final - : public IAggregateFunctionDataHelper, GroupArrayGeneralListImpl> +template +class GroupArrayGeneralListImpl final : public IAggregateFunctionDataHelper< + GroupArrayGeneralListData, + GroupArrayGeneralListImpl> { + static constexpr bool limit_num_elems = Trait::has_limit; using Data = GroupArrayGeneralListData; static Data & data(AggregateDataPtr place) { return *reinterpret_cast(place); } static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast(place); } @@ -253,10 +758,15 @@ class GroupArrayGeneralListImpl final public: GroupArrayGeneralListImpl(const DataTypePtr & data_type_, UInt64 max_elems_ = std::numeric_limits::max()) - : IAggregateFunctionDataHelper, GroupArrayGeneralListImpl>({data_type_}, {}) - , data_type(this->argument_types[0]), max_elems(max_elems_) {} + : IAggregateFunctionDataHelper< + GroupArrayGeneralListData, + GroupArrayGeneralListImpl>({data_type_}, {}) + , data_type(this->argument_types[0]) + , max_elems(max_elems_) + { + } - String getName() const override { return "groupArray"; } + String getName() const override { return getNameByTrait(); } DataTypePtr getReturnType() const override { return std::make_shared(data_type); } diff --git a/dbms/src/Common/ArenaAllocator.h b/dbms/src/Common/ArenaAllocator.h index 6eb415d5e54..948bb4ddd0a 100644 --- a/dbms/src/Common/ArenaAllocator.h +++ b/dbms/src/Common/ArenaAllocator.h @@ -35,6 +35,12 @@ public: { // Do nothing, trash in arena remains. } + +protected: + static constexpr size_t getStackThreshold() + { + return 0; + } }; @@ -66,6 +72,12 @@ public: static void free(void * /*buf*/, size_t /*size*/) { } + +protected: + static constexpr size_t getStackThreshold() + { + return 0; + } }; @@ -100,6 +112,12 @@ public: if (size >= REAL_ALLOCATION_TRESHOLD) TRealAllocator::free(buf, size); } + +protected: + static constexpr size_t getStackThreshold() + { + return 0; + } }; @@ -136,6 +154,12 @@ public: } void free(void * /*buf*/, size_t /*size*/) {} + +protected: + static constexpr size_t getStackThreshold() + { + return N; + } }; } diff --git a/dbms/src/Common/PODArray.h b/dbms/src/Common/PODArray.h index 441befd2d5f..e9c85792e40 100644 --- a/dbms/src/Common/PODArray.h +++ b/dbms/src/Common/PODArray.h @@ -150,7 +150,7 @@ protected: bool isAllocatedFromStack() const { - constexpr size_t stack_threshold = TAllocator::getStackThreshold(); + static constexpr size_t stack_threshold = TAllocator::getStackThreshold(); return (stack_threshold > 0) && (allocated_bytes() <= stack_threshold); } @@ -453,7 +453,8 @@ public: this->c_end += bytes_to_copy; } - void swap(PODArray & rhs) + template + void swap(PODArray & rhs, TAllocatorParams &&... allocator_params) { #ifndef NDEBUG this->unprotect(); @@ -463,7 +464,7 @@ public: /// Swap two PODArray objects, arr1 and arr2, that satisfy the following conditions: /// - The elements of arr1 are stored on stack. /// - The elements of arr2 are stored on heap. - auto swap_stack_heap = [this](PODArray & arr1, PODArray & arr2) + auto swap_stack_heap = [&](PODArray & arr1, PODArray & arr2) { size_t stack_size = arr1.size(); size_t stack_allocated = arr1.allocated_bytes(); @@ -480,18 +481,18 @@ public: arr1.c_end = arr1.c_start + this->byte_size(heap_size); /// Allocate stack space for arr2. - arr2.alloc(stack_allocated); + arr2.alloc(stack_allocated, std::forward(allocator_params)...); /// Copy the stack content. memcpy(arr2.c_start, stack_c_start, this->byte_size(stack_size)); arr2.c_end = arr2.c_start + this->byte_size(stack_size); }; - auto do_move = [this](PODArray & src, PODArray & dest) + auto do_move = [&](PODArray & src, PODArray & dest) { if (src.isAllocatedFromStack()) { dest.dealloc(); - dest.alloc(src.allocated_bytes()); + dest.alloc(src.allocated_bytes(), std::forward(allocator_params)...); memcpy(dest.c_start, src.c_start, this->byte_size(src.size())); dest.c_end = dest.c_start + (src.c_end - src.c_start); @@ -569,24 +570,26 @@ public: } } - void assign(size_t n, const T & x) + template + void assign(size_t n, const T & x, TAllocatorParams &&... allocator_params) { - this->resize(n); + this->resize(n, std::forward(allocator_params)...); std::fill(begin(), end(), x); } - template - void assign(It1 from_begin, It2 from_end) + template + void assign(It1 from_begin, It2 from_end, TAllocatorParams &&... allocator_params) { size_t required_capacity = from_end - from_begin; if (required_capacity > this->capacity()) - this->reserve(roundUpToPowerOfTwoOrZero(required_capacity)); + this->reserve(roundUpToPowerOfTwoOrZero(required_capacity), std::forward(allocator_params)...); size_t bytes_to_copy = this->byte_size(required_capacity); memcpy(this->c_start, reinterpret_cast(&*from_begin), bytes_to_copy); this->c_end = this->c_start + bytes_to_copy; } + // ISO C++ has strict ambiguity rules, thus we cannot apply TAllocatorParams here. void assign(const PODArray & from) { assign(from.begin(), from.end()); diff --git a/dbms/tests/queries/0_stateless/01050_group_array_sample.reference b/dbms/tests/queries/0_stateless/01050_group_array_sample.reference new file mode 100644 index 00000000000..3a513512f93 --- /dev/null +++ b/dbms/tests/queries/0_stateless/01050_group_array_sample.reference @@ -0,0 +1,8 @@ +0 [576,800,64,936,552,216,252,808,920,780] +1 [577,801,65,937,553,217,253,809,921,781] +2 [578,802,66,938,554,218,254,810,922,782] +3 [579,803,67,939,555,219,255,811,923,783] +0 [128,184,304,140,568,528,772,452,176,648] +1 [129,185,305,141,569,529,773,453,177,649] +2 [130,186,306,142,570,530,774,454,178,650] +3 [131,187,307,143,571,531,775,455,179,651] diff --git a/dbms/tests/queries/0_stateless/01050_group_array_sample.sql b/dbms/tests/queries/0_stateless/01050_group_array_sample.sql new file mode 100644 index 00000000000..395ab9d41b6 --- /dev/null +++ b/dbms/tests/queries/0_stateless/01050_group_array_sample.sql @@ -0,0 +1,4 @@ +select k, groupArraySample(10)(v) from (select number % 4 as k, number as v from numbers(1024)) group by k; + +-- different seed +select k, groupArraySample(10, 1)(v) from (select number % 4 as k, number as v from numbers(1024)) group by k;