diff --git a/dbms/src/Columns/ColumnWithDictionary.cpp b/dbms/src/Columns/ColumnWithDictionary.cpp index 59bca74ebad..2a0e211d1b2 100644 --- a/dbms/src/Columns/ColumnWithDictionary.cpp +++ b/dbms/src/Columns/ColumnWithDictionary.cpp @@ -356,6 +356,17 @@ ColumnWithDictionary::getMinimalDictionaryEncodedColumn(size_t offset, size_t li return {std::move(sub_keys), std::move(sub_indexes)}; } +ColumnPtr ColumnWithDictionary::countKeys() const +{ + const auto & nested_column = getDictionary().getNestedColumn(); + size_t dict_size = nested_column->size(); + + auto counter = ColumnUInt64::create(dict_size, 0); + idx.countKeys(counter->getData()); + return std::move(counter); +} + + ColumnWithDictionary::Index::Index() : positions(ColumnUInt8::create()), size_of_type(sizeof(UInt8)) {} @@ -430,6 +441,18 @@ typename ColumnVector::Container & ColumnWithDictionary::Index::getPo return positions_ptr->getData(); } +template +const typename ColumnVector::Container & ColumnWithDictionary::Index::getPositionsData() const +{ + const auto * positions_ptr = typeid_cast *>(positions.get()); + if (!positions_ptr) + throw Exception("Invalid indexes type for ColumnWithDictionary." 
+ " Expected UInt" + toString(8 * sizeof(IndexType)) + ", got " + positions->getName(), + ErrorCodes::LOGICAL_ERROR); + + return positions_ptr->getData(); +} + template void ColumnWithDictionary::Index::convertPositions() { @@ -616,4 +639,17 @@ void ColumnWithDictionary::Dictionary::compact(ColumnPtr & positions) shared = false; } + +void ColumnWithDictionary::Index::countKeys(ColumnUInt64::Container & counts) const +{ + auto counter = [&](auto x) + { + using CurIndexType = decltype(x); + auto & data = getPositionsData(); + for (auto pos : data) + ++counts[pos]; + }; + callForType(std::move(counter), size_of_type); +} + } diff --git a/dbms/src/Columns/ColumnWithDictionary.h b/dbms/src/Columns/ColumnWithDictionary.h index 5d68dca5796..84a13fb5b66 100644 --- a/dbms/src/Columns/ColumnWithDictionary.h +++ b/dbms/src/Columns/ColumnWithDictionary.h @@ -164,6 +164,8 @@ public: DictionaryEncodedColumn getMinimalDictionaryEncodedColumn(size_t offset, size_t limit) const; + ColumnPtr countKeys() const; + class Index { public: @@ -190,6 +192,8 @@ public: ColumnPtr detachPositions() { return std::move(positions); } void attachPositions(ColumnPtr positions_); + void countKeys(ColumnUInt64::Container & counts) const; + private: ColumnPtr positions; size_t size_of_type = 0; @@ -200,6 +204,9 @@ public: template typename ColumnVector::Container & getPositionsData(); + template + const typename ColumnVector::Container & getPositionsData() const; + template void convertPositions(); diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index bf629903e09..8257c6a77e1 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -25,6 +25,8 @@ #include #if __has_include() #include +#include + #endif @@ -534,6 +536,28 @@ void NO_INLINE Aggregator::executeImpl( executeImplCase(method, state, aggregates_pool, rows, key_columns, aggregate_instructions, key_sizes, keys, overflow_row); } +template +void NO_INLINE 
Aggregator::executeLowCardinalityImpl( + Method & method, + Arena * aggregates_pool, + size_t rows, + ColumnRawPtrs & key_columns, + ColumnRawPtrs & key_counts, + AggregateFunctionInstruction * aggregate_instructions, + const Sizes & key_sizes, + StringRefs & keys, + bool no_more_keys, + AggregateDataPtr overflow_row) const +{ + typename Method::State state; + state.init(key_columns); + + if (!no_more_keys) + executeLowCardinalityImplCase(method, state, aggregates_pool, rows, key_columns, key_counts, aggregate_instructions, key_sizes, keys, overflow_row); + else + executeLowCardinalityImplCase(method, state, aggregates_pool, rows, key_columns, key_counts, aggregate_instructions, key_sizes, keys, overflow_row); +} + #ifndef __clang__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" @@ -626,6 +650,100 @@ void NO_INLINE Aggregator::executeImplCase( } } +template +void NO_INLINE Aggregator::executeLowCardinalityImplCase( + Method & method, + typename Method::State & state, + Arena * aggregates_pool, + size_t rows, + ColumnRawPtrs & key_columns, + ColumnRawPtrs & key_counts, + AggregateFunctionInstruction * aggregate_instructions, + const Sizes & key_sizes, + StringRefs & keys, + AggregateDataPtr overflow_row) const +{ + /// NOTE When editing this code, also pay attention to SpecializedAggregator.h. + + auto & counts_data = static_cast(key_counts[0])->getData(); + + /// For all rows. + typename Method::iterator it; + typename Method::Key prev_key; + for (size_t i = 0; i < rows; ++i) + { + /// Get the key to insert into the hash table. + typename Method::Key key = state.getKey(key_columns, params.keys_size, i, key_sizes, keys, *aggregates_pool); + + size_t num_repeats = counts_data[i]; + for (size_t repeat = 0; repeat < num_repeats; ++repeat) + { + bool inserted; /// Inserted a new key, or was this key already? + bool overflow = false; /// The new key did not fit in the hash table because of no_more_keys. 
+ + if (!no_more_keys) /// Insert. + { + /// Optimization for consecutive identical keys. + if (!Method::no_consecutive_keys_optimization) + { + if (i != 0 && (repeat || key == prev_key)) + { + /// Add values to the aggregate functions. + AggregateDataPtr value = Method::getAggregateData(it->second); + for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst) + (*inst->func)(inst->that, value + inst->state_offset, inst->arguments, i, aggregates_pool); + + method.onExistingKey(key, keys, *aggregates_pool); + continue; + } + else + prev_key = key; + } + + method.data.emplace(key, it, inserted); + } + else + { + /// Add only if the key already exists. + inserted = false; + it = method.data.find(key); + if (method.data.end() == it) + overflow = true; + } + + /// If the key does not fit, and the data does not need to be aggregated in a separate row, then there's nothing to do. + if (no_more_keys && overflow && !overflow_row) + { + method.onExistingKey(key, keys, *aggregates_pool); + continue; + } + + /// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key. + if (inserted) + { + AggregateDataPtr & aggregate_data = Method::getAggregateData(it->second); + + /// exception-safety - if you can not allocate memory or create states, then destructors will not be called. + aggregate_data = nullptr; + + method.onNewKey(*it, params.keys_size, keys, *aggregates_pool); + + AggregateDataPtr place = aggregates_pool->alloc(total_size_of_aggregate_states); + createAggregateStates(place); + aggregate_data = place; + } + else + method.onExistingKey(key, keys, *aggregates_pool); + + AggregateDataPtr value = (!no_more_keys || !overflow) ? Method::getAggregateData(it->second) : overflow_row; + + /// Add values to the aggregate functions. 
+ for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst) + (*inst->func)(inst->that, value + inst->state_offset, inst->arguments, i, aggregates_pool); + } + } +} + #ifndef __clang__ #pragma GCC diagnostic pop #endif @@ -672,6 +790,7 @@ bool Aggregator::executeOnBlock(const Block & block, AggregatedDataVariants & re * To make them work anyway, we materialize them. */ Columns materialized_columns; + ColumnRawPtrs key_counts; /// Remember the columns we will work with for (size_t i = 0; i < params.keys_size; ++i) @@ -683,6 +802,20 @@ bool Aggregator::executeOnBlock(const Block & block, AggregatedDataVariants & re materialized_columns.push_back(converted); key_columns[i] = materialized_columns.back().get(); } + + if (const auto * column_with_dictionary = typeid_cast(key_columns[i])) + { + if (params.keys_size == 1) + { + materialized_columns.push_back(column_with_dictionary->countKeys()); + key_columns[i] = column_with_dictionary->getDictionary().getNestedColumn().get(); + } + else + { + materialized_columns.push_back(column_with_dictionary->convertToFullColumn()); + } + key_counts.push_back(materialized_columns.back().get()); + } } AggregateFunctionInstructions aggregate_functions_instructions(params.aggregates_size + 1); diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index 5919a296085..68170a28025 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -1215,6 +1215,19 @@ protected: bool no_more_keys, AggregateDataPtr overflow_row) const; + template + void executeLowCardinalityImpl( + Method & method, + Arena * aggregates_pool, + size_t rows, + ColumnRawPtrs & key_columns, + ColumnRawPtrs & key_counts, + AggregateFunctionInstruction * aggregate_instructions, + const Sizes & key_sizes, + StringRefs & keys, + bool no_more_keys, + AggregateDataPtr overflow_row) const; + /// Specialization for a particular value no_more_keys. 
template void executeImplCase( @@ -1228,6 +1241,19 @@ protected: StringRefs & keys, AggregateDataPtr overflow_row) const; + template + void executeLowCardinalityImplCase( + Method & method, + typename Method::State & state, + Arena * aggregates_pool, + size_t rows, + ColumnRawPtrs & key_columns, + ColumnRawPtrs & key_counts, + AggregateFunctionInstruction * aggregate_instructions, + const Sizes & key_sizes, + StringRefs & keys, + AggregateDataPtr overflow_row) const; + /// For case when there are no keys (all aggregate into one row). void executeWithoutKeyImpl( AggregatedDataWithoutKey & res,