From 90d8290d2828c0219045b5c95c0527b97900a904 Mon Sep 17 00:00:00 2001 From: "chenxing.xc" Date: Sun, 5 Aug 2018 16:45:15 +0800 Subject: [PATCH] aligned aggregate state --- dbms/src/Columns/ColumnAggregateFunction.cpp | 6 ++-- dbms/src/Common/Arena.h | 24 +++++++++++++ dbms/src/Interpreters/Aggregator.cpp | 36 +++++++++++++++++--- dbms/src/Interpreters/Aggregator.h | 5 +++ 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/dbms/src/Columns/ColumnAggregateFunction.cpp b/dbms/src/Columns/ColumnAggregateFunction.cpp index b989c007e56..07b2feecb53 100644 --- a/dbms/src/Columns/ColumnAggregateFunction.cpp +++ b/dbms/src/Columns/ColumnAggregateFunction.cpp @@ -264,7 +264,7 @@ void ColumnAggregateFunction::insert(const Field & x) Arena & arena = createOrGetArena(); - getData().push_back(arena.alloc(function->sizeOfData())); + getData().push_back(arena.alloc_align(function->sizeOfData(), function->alignOfData())); function->create(getData().back()); ReadBufferFromString read_buffer(x.get()); function->deserialize(getData().back(), read_buffer, &arena); @@ -276,7 +276,7 @@ void ColumnAggregateFunction::insertDefault() Arena & arena = createOrGetArena(); - getData().push_back(arena.alloc(function->sizeOfData())); + getData().push_back(arena.alloc_align(function->sizeOfData(), function->alignOfData())); function->create(getData().back()); } @@ -297,7 +297,7 @@ const char * ColumnAggregateFunction::deserializeAndInsertFromArena(const char * */ Arena & dst_arena = createOrGetArena(); - getData().push_back(dst_arena.alloc(function->sizeOfData())); + getData().push_back(dst_arena.alloc_align(function->sizeOfData(), function->alignOfData())); function->create(getData().back()); /** We will read from src_arena. diff --git a/dbms/src/Common/Arena.h b/dbms/src/Common/Arena.h index 7538b9a71a6..39ced9d002e 100644 --- a/dbms/src/Common/Arena.h +++ b/dbms/src/Common/Arena.h @@ -124,6 +124,30 @@ public: return res; } + /// Get peice of memory with alignment + char * alloc_align(size_t size, size_t align) + { + // Fast code path for non-alignment requirement + if (align <= 1) return alloc(size); + + size_t pos = (size_t)(head->pos); + // next pos match align requirement + pos = (pos & ~(align - 1)) + align; + + if (unlikely(pos + size > (size_t)(head->end))) + { + addChunk(size); + pos = (size_t)(head->pos); + pos = (pos & ~(align - 1)) + align; + } + + char * res = (char *)pos; + + head->pos = res + size; + + return res; + } + /** Rollback just performed allocation. * Must pass size not more that was just allocated. */ diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index ef6075e5761..aa6eac8cc0a 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -151,10 +151,34 @@ Aggregator::Aggregator(const Params & params_) total_size_of_aggregate_states = 0; all_aggregates_has_trivial_destructor = true; + // aggreate_states will be aligned as below: + // |<-- state_1 -->|<-- pad_1 -->|<-- state_2 -->|<-- pad_2 -->| ..... + // + // pad_N will be used to match alignment requirement for each next state. + // The address of state_1 is aligned based on maximum alignment requirements in states for (size_t i = 0; i < params.aggregates_size; ++i) { offsets_of_aggregate_states[i] = total_size_of_aggregate_states; + total_size_of_aggregate_states += params.aggregates[i].function->sizeOfData(); + + // aggreate states are aligned based on maximum requirement + align_aggregate_states = std::max(align_aggregate_states, + params.aggregates[i].function->alignOfData()); + + // If not the last aggregate_state, we need pad it so that next aggregate_state will be + // aligned. + if (i + 1 < params.aggregates_size) + { + size_t next_align_req = params.aggregates[i+1].function->alignOfData(); + if ((next_align_req & (next_align_req -1)) != 0) + { + throw Exception("alignOfData is not 2^N"); + } + // extend total_size to next alignment requirement + total_size_of_aggregate_states = + (total_size_of_aggregate_states & ~(next_align_req - 1)) + next_align_req; + } if (!params.aggregates[i].function->hasTrivialDestructor()) all_aggregates_has_trivial_destructor = false; @@ -605,7 +629,8 @@ void NO_INLINE Aggregator::executeImplCase( method.onNewKey(*it, params.keys_size, keys, *aggregates_pool); - AggregateDataPtr place = aggregates_pool->alloc(total_size_of_aggregate_states); + AggregateDataPtr place = aggregates_pool->alloc_align(total_size_of_aggregate_states, + align_aggregate_states); createAggregateStates(place); aggregate_data = place; } @@ -723,7 +748,8 @@ bool Aggregator::executeOnBlock(const Block & block, AggregatedDataVariants & re if ((params.overflow_row || result.type == AggregatedDataVariants::Type::without_key) && !result.without_key) { - AggregateDataPtr place = result.aggregates_pool->alloc(total_size_of_aggregate_states); + AggregateDataPtr place = result.aggregates_pool->alloc_align(total_size_of_aggregate_states, + align_aggregate_states); createAggregateStates(place); result.without_key = place; } @@ -1891,7 +1917,8 @@ void NO_INLINE Aggregator::mergeStreamsImplCase( method.onNewKey(*it, params.keys_size, keys, *aggregates_pool); - AggregateDataPtr place = aggregates_pool->alloc(total_size_of_aggregate_states); + AggregateDataPtr place = aggregates_pool->alloc_align(total_size_of_aggregate_states, + align_aggregate_states); createAggregateStates(place); aggregate_data = place; } @@ -1942,7 +1969,8 @@ void NO_INLINE Aggregator::mergeWithoutKeyStreamsImpl( AggregatedDataWithoutKey & res = result.without_key; if (!res) { - AggregateDataPtr place = result.aggregates_pool->alloc(total_size_of_aggregate_states); + AggregateDataPtr place = result.aggregates_pool->alloc_align(total_size_of_aggregate_states, + align_aggregate_states); createAggregateStates(place); res = place; } diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index 5919a296085..78312fe2250 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -1154,6 +1154,11 @@ protected: Sizes offsets_of_aggregate_states; /// The offset to the n-th aggregate function in a row of aggregate functions. size_t total_size_of_aggregate_states = 0; /// The total size of the row from the aggregate functions. + + // add info to track aligned requirement + // If there are states whose alignment info were v1, ..vn, align_aggregate_states will be max(v1, ... vn) + size_t align_aggregate_states = 1; + bool all_aggregates_has_trivial_destructor = false; /// How many RAM were used to process the query before processing the first block.