diff --git a/dbms/src/Columns/ColumnAggregateFunction.cpp b/dbms/src/Columns/ColumnAggregateFunction.cpp index fb59ad90ec8..38336562dce 100644 --- a/dbms/src/Columns/ColumnAggregateFunction.cpp +++ b/dbms/src/Columns/ColumnAggregateFunction.cpp @@ -284,7 +284,7 @@ void ColumnAggregateFunction::insert(const Field & x) Arena & arena = createOrGetArena(); - getData().push_back(arena.alloc(function->sizeOfData())); + getData().push_back(arena.alignedAlloc(function->sizeOfData(), function->alignOfData())); function->create(getData().back()); ReadBufferFromString read_buffer(x.get()); function->deserialize(getData().back(), read_buffer, &arena); @@ -296,7 +296,7 @@ void ColumnAggregateFunction::insertDefault() Arena & arena = createOrGetArena(); - getData().push_back(arena.alloc(function->sizeOfData())); + getData().push_back(arena.alignedAlloc(function->sizeOfData(), function->alignOfData())); function->create(getData().back()); } @@ -317,7 +317,7 @@ const char * ColumnAggregateFunction::deserializeAndInsertFromArena(const char * */ Arena & dst_arena = createOrGetArena(); - getData().push_back(dst_arena.alloc(function->sizeOfData())); + getData().push_back(dst_arena.alignedAlloc(function->sizeOfData(), function->alignOfData())); function->create(getData().back()); /** We will read from src_arena. diff --git a/dbms/src/Common/Arena.h b/dbms/src/Common/Arena.h index 7538b9a71a6..6f36d352cab 100644 --- a/dbms/src/Common/Arena.h +++ b/dbms/src/Common/Arena.h @@ -124,6 +124,29 @@ public: return res; } + /// Get peice of memory with alignment + char * alignedAlloc(size_t size, size_t align) + { + // Fast code path for non-alignment requirement + if (align <= 1) + return alloc(size); + + uintptr_t pos = reinterpret_cast(head->pos); + // next pos match align requirement + pos = ((pos-1) & ~(align - 1)) + align; + + if (unlikely(pos + size > reinterpret_cast(head->end))) + { + addChunk(size); + pos = reinterpret_cast(head->pos); + pos = ((pos-1) & ~(align - 1)) + align; + } + char * res = (char *)pos; + head->pos = res + size; + + return res; + } + /** Rollback just performed allocation. * Must pass size not more that was just allocated. */ diff --git a/dbms/src/DataTypes/DataTypeAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeAggregateFunction.cpp index f005f2e2eea..f4384ce321e 100644 --- a/dbms/src/DataTypes/DataTypeAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeAggregateFunction.cpp @@ -82,7 +82,7 @@ void DataTypeAggregateFunction::deserializeBinary(IColumn & column, ReadBuffer & Arena & arena = column_concrete.createOrGetArena(); size_t size_of_state = function->sizeOfData(); - AggregateDataPtr place = arena.alloc(size_of_state); + AggregateDataPtr place = arena.alignedAlloc(size_of_state, function->alignOfData()); function->create(place); try @@ -123,13 +123,14 @@ void DataTypeAggregateFunction::deserializeBinaryBulk(IColumn & column, ReadBuff vec.reserve(vec.size() + limit); size_t size_of_state = function->sizeOfData(); + size_t align_of_state = function->alignOfData(); for (size_t i = 0; i < limit; ++i) { if (istr.eof()) break; - AggregateDataPtr place = arena.alloc(size_of_state); + AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state); function->create(place); @@ -160,7 +161,7 @@ static void deserializeFromString(const AggregateFunctionPtr & function, IColumn Arena & arena = column_concrete.createOrGetArena(); size_t size_of_state = function->sizeOfData(); - AggregateDataPtr place = arena.alloc(size_of_state); + AggregateDataPtr place = arena.alignedAlloc(size_of_state, function->alignOfData()); function->create(place); diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index dc9dde6acb8..db95a89badd 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -153,10 +153,34 @@ Aggregator::Aggregator(const Params & params_) total_size_of_aggregate_states = 0; all_aggregates_has_trivial_destructor = true; + // aggreate_states will be aligned as below: + // |<-- state_1 -->|<-- pad_1 -->|<-- state_2 -->|<-- pad_2 -->| ..... + // + // pad_N will be used to match alignment requirement for each next state. + // The address of state_1 is aligned based on maximum alignment requirements in states for (size_t i = 0; i < params.aggregates_size; ++i) { offsets_of_aggregate_states[i] = total_size_of_aggregate_states; + total_size_of_aggregate_states += params.aggregates[i].function->sizeOfData(); + + // aggreate states are aligned based on maximum requirement + align_aggregate_states = std::max(align_aggregate_states, + params.aggregates[i].function->alignOfData()); + + // If not the last aggregate_state, we need pad it so that next aggregate_state will be + // aligned. + if (i + 1 < params.aggregates_size) + { + size_t next_align_req = params.aggregates[i+1].function->alignOfData(); + if ((next_align_req & (next_align_req -1)) != 0) + { + throw Exception("alignOfData is not 2^N"); + } + // extend total_size to next alignment requirement + total_size_of_aggregate_states = + (total_size_of_aggregate_states & ~(next_align_req - 1)) + next_align_req; + } if (!params.aggregates[i].function->hasTrivialDestructor()) all_aggregates_has_trivial_destructor = false; @@ -613,7 +637,8 @@ void NO_INLINE Aggregator::executeImplCase( method.onNewKey(*it, params.keys_size, keys, *aggregates_pool); - AggregateDataPtr place = aggregates_pool->alloc(total_size_of_aggregate_states); + AggregateDataPtr place = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, + align_aggregate_states); createAggregateStates(place); aggregate_data = place; } @@ -731,7 +756,8 @@ bool Aggregator::executeOnBlock(const Block & block, AggregatedDataVariants & re if ((params.overflow_row || result.type == AggregatedDataVariants::Type::without_key) && !result.without_key) { - AggregateDataPtr place = result.aggregates_pool->alloc(total_size_of_aggregate_states); + AggregateDataPtr place = result.aggregates_pool->alignedAlloc(total_size_of_aggregate_states, + align_aggregate_states); createAggregateStates(place); result.without_key = place; } @@ -1899,7 +1925,8 @@ void NO_INLINE Aggregator::mergeStreamsImplCase( method.onNewKey(*it, params.keys_size, keys, *aggregates_pool); - AggregateDataPtr place = aggregates_pool->alloc(total_size_of_aggregate_states); + AggregateDataPtr place = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, + align_aggregate_states); createAggregateStates(place); aggregate_data = place; } @@ -1950,7 +1977,8 @@ void NO_INLINE Aggregator::mergeWithoutKeyStreamsImpl( AggregatedDataWithoutKey & res = result.without_key; if (!res) { - AggregateDataPtr place = result.aggregates_pool->alloc(total_size_of_aggregate_states); + AggregateDataPtr place = result.aggregates_pool->alignedAlloc(total_size_of_aggregate_states, + align_aggregate_states); createAggregateStates(place); res = place; } @@ -2002,7 +2030,7 @@ void Aggregator::mergeStream(const BlockInputStreamPtr & stream, AggregatedDataV * If there is at least one block with a bucket number greater than zero, then there was a two-level aggregation. */ auto max_bucket = bucket_to_blocks.rbegin()->first; - size_t has_two_level = max_bucket > 0; + size_t has_two_level = max_bucket >= 0; if (has_two_level) { diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index 5919a296085..78312fe2250 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -1154,6 +1154,11 @@ protected: Sizes offsets_of_aggregate_states; /// The offset to the n-th aggregate function in a row of aggregate functions. size_t total_size_of_aggregate_states = 0; /// The total size of the row from the aggregate functions. + + // add info to track aligned requirement + // If there are states whose alignment info were v1, ..vn, align_aggregate_states will be max(v1, ... vn) + size_t align_aggregate_states = 1; + bool all_aggregates_has_trivial_destructor = false; /// How many RAM were used to process the query before processing the first block. diff --git a/dbms/src/Interpreters/SpecializedAggregator.h b/dbms/src/Interpreters/SpecializedAggregator.h index db56ea37633..3313f03c551 100644 --- a/dbms/src/Interpreters/SpecializedAggregator.h +++ b/dbms/src/Interpreters/SpecializedAggregator.h @@ -191,7 +191,8 @@ void NO_INLINE Aggregator::executeSpecializedCase( method.onNewKey(*it, params.keys_size, keys, *aggregates_pool); - AggregateDataPtr place = aggregates_pool->alloc(total_size_of_aggregate_states); + AggregateDataPtr place = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, + align_aggregate_states); AggregateFunctionsList::forEach(AggregateFunctionsCreator( aggregate_functions, offsets_of_aggregate_states, place));