#pragma once #include #include #include #include #include #include #include #include #include namespace DB { template struct AggregateFunctionTopKData { using Set = SpaceSaving < T, HashCRC32, HashTableGrower<4>, HashTableAllocatorWithStackMemory >; Set value; }; template class AggregateFunctionTopK : public IAggregateFunctionDataHelper, AggregateFunctionTopK> { protected: using State = AggregateFunctionTopKData; UInt64 threshold; UInt64 reserved; public: AggregateFunctionTopK(UInt64 threshold, UInt64 load_factor, const DataTypes & argument_types_, const Array & params) : IAggregateFunctionDataHelper, AggregateFunctionTopK>(argument_types_, params) , threshold(threshold), reserved(load_factor * threshold) {} String getName() const override { return is_weighted ? "topKWeighted" : "topK"; } DataTypePtr getReturnType() const override { return std::make_shared(std::make_shared>()); } void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { auto & set = this->data(place).value; if (set.capacity() != reserved) set.resize(reserved); if constexpr (is_weighted) set.insert(static_cast &>(*columns[0]).getData()[row_num], columns[1]->getUInt(row_num)); else set.insert(static_cast &>(*columns[0]).getData()[row_num]); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).value.merge(this->data(rhs).value); } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(place).value.write(buf); } void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override { auto & set = this->data(place).value; set.resize(reserved); set.read(buf); } void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { ColumnArray & arr_to = static_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); const typename State::Set & set = this->data(place).value; auto result_vec = set.topK(threshold); size_t size = result_vec.size(); offsets_to.push_back(offsets_to.back() + size); typename ColumnVector::Container & data_to = static_cast &>(arr_to.getData()).getData(); size_t old_size = data_to.size(); data_to.resize(old_size + size); size_t i = 0; for (auto it = result_vec.begin(); it != result_vec.end(); ++it, ++i) data_to[old_size + i] = it->key; } const char * getHeaderFilePath() const override { return __FILE__; } }; /// Generic implementation, it uses serialized representation as object descriptor. struct AggregateFunctionTopKGenericData { using Set = SpaceSaving < StringRef, StringRefHash, HashTableGrower<4>, HashTableAllocatorWithStackMemory >; Set value; }; /** Template parameter with true value should be used for columns that store their elements in memory continuously. * For such columns topK() can be implemented more efficiently (especially for small numeric arrays). */ template class AggregateFunctionTopKGeneric : public IAggregateFunctionDataHelper> { private: using State = AggregateFunctionTopKGenericData; UInt64 threshold; UInt64 reserved; DataTypePtr & input_data_type; static void deserializeAndInsert(StringRef str, IColumn & data_to); public: AggregateFunctionTopKGeneric( UInt64 threshold, UInt64 load_factor, const DataTypePtr & input_data_type, const Array & params) : IAggregateFunctionDataHelper>({input_data_type}, params) , threshold(threshold), reserved(load_factor * threshold), input_data_type(this->argument_types[0]) {} String getName() const override { return is_weighted ? "topKWeighted" : "topK"; } DataTypePtr getReturnType() const override { return std::make_shared(input_data_type); } bool allocatesMemoryInArena() const override { return true; } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(place).value.write(buf); } void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override { auto & set = this->data(place).value; set.clear(); set.resize(reserved); // Specialized here because there's no deserialiser for StringRef size_t size = 0; readVarUInt(size, buf); for (size_t i = 0; i < size; ++i) { auto ref = readStringBinaryInto(*arena, buf); UInt64 count; UInt64 error; readVarUInt(count, buf); readVarUInt(error, buf); set.insert(ref, count, error); arena->rollback(ref.size); } set.readAlphaMap(buf); } void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override { auto & set = this->data(place).value; if (set.capacity() != reserved) set.resize(reserved); if constexpr (is_plain_column) { if constexpr (is_weighted) set.insert(columns[0]->getDataAt(row_num), columns[1]->getUInt(row_num)); else set.insert(columns[0]->getDataAt(row_num)); } else { const char * begin = nullptr; StringRef str_serialized = columns[0]->serializeValueIntoArena(row_num, *arena, begin); if constexpr (is_weighted) set.insert(str_serialized, columns[1]->getUInt(row_num)); else set.insert(str_serialized); arena->rollback(str_serialized.size); } } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { this->data(place).value.merge(this->data(rhs).value); } void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { ColumnArray & arr_to = static_cast(to); ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); IColumn & data_to = arr_to.getData(); auto result_vec = this->data(place).value.topK(threshold); offsets_to.push_back(offsets_to.back() + result_vec.size()); for (auto & elem : result_vec) { if constexpr (is_plain_column) data_to.insertData(elem.key.data, elem.key.size); else data_to.deserializeAndInsertFromArena(elem.key.data); } } const char * getHeaderFilePath() const override { return __FILE__; } }; }