Aggregator merge and destroy states in batch

This commit is contained in:
Maksim Kita 2023-08-24 15:34:32 +03:00
parent 698ceee1a9
commit f0f2d416dd
4 changed files with 81 additions and 40 deletions

View File

@ -169,6 +169,10 @@ public:
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "merge() with thread pool parameter isn't implemented for {} ", getName());
}
/// Merges states (on which src places points to) with other states (on which dst places points to) of current aggregation function
/// then destroy states (on which src places points to).
virtual void mergeAndDestroyBatch(AggregateDataPtr * dst_places, AggregateDataPtr * src_places, size_t size, size_t offset, Arena * arena) const = 0;
/// Serializes state (to transmit it over the network, for example).
virtual void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version = std::nullopt) const = 0; /// NOLINT
@ -506,6 +510,15 @@ public:
static_cast<const Derived *>(this)->merge(places[i] + place_offset, rhs[i], arena);
}
void mergeAndDestroyBatch(AggregateDataPtr * dst_places, AggregateDataPtr * rhs_places, size_t size, size_t offset, Arena * arena) const override
{
for (size_t i = 0; i < size; ++i)
{
static_cast<const Derived *>(this)->merge(dst_places[i] + offset, rhs_places[i] + offset, arena);
static_cast<const Derived *>(this)->destroy(rhs_places[i] + offset);
}
}
void addBatchSinglePlace( /// NOLINT
size_t row_begin,
size_t row_end,

View File

@ -2479,48 +2479,21 @@ void NO_INLINE Aggregator::mergeDataNullKey(
}
}
template <typename Method, bool use_compiled_functions, bool prefetch, typename Table>
void NO_INLINE Aggregator::mergeDataImpl(Table & table_dst, Table & table_src, Arena * arena) const
{
if constexpr (Method::low_cardinality_optimization || Method::one_key_nullable_optimization)
mergeDataNullKey<Method, Table>(table_dst, table_src, arena);
PaddedPODArray<AggregateDataPtr> dst_places;
PaddedPODArray<AggregateDataPtr> src_places;
auto merge = [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
{
if (!inserted)
{
#if USE_EMBEDDED_COMPILER
if constexpr (use_compiled_functions)
{
const auto & compiled_functions = compiled_aggregate_functions_holder->compiled_aggregate_functions;
compiled_functions.merge_aggregate_states_function(dst, src);
if (compiled_aggregate_functions_holder->compiled_aggregate_functions.functions_count != params.aggregates_size)
{
for (size_t i = 0; i < params.aggregates_size; ++i)
{
if (!is_aggregate_function_compiled[i])
aggregate_functions[i]->merge(
dst + offsets_of_aggregate_states[i], src + offsets_of_aggregate_states[i], arena);
}
for (size_t i = 0; i < params.aggregates_size; ++i)
{
if (!is_aggregate_function_compiled[i])
aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]);
}
}
}
else
#endif
{
for (size_t i = 0; i < params.aggregates_size; ++i)
aggregate_functions[i]->merge(dst + offsets_of_aggregate_states[i], src + offsets_of_aggregate_states[i], arena);
for (size_t i = 0; i < params.aggregates_size; ++i)
aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]);
}
dst_places.push_back(dst);
src_places.push_back(src);
}
else
{
@ -2531,8 +2504,30 @@ void NO_INLINE Aggregator::mergeDataImpl(Table & table_dst, Table & table_src, A
};
table_src.template mergeToViaEmplace<decltype(merge), prefetch>(table_dst, std::move(merge));
table_src.clearAndShrink();
#if USE_EMBEDDED_COMPILER
if constexpr (use_compiled_functions)
{
const auto & compiled_functions = compiled_aggregate_functions_holder->compiled_aggregate_functions;
compiled_functions.merge_aggregate_states_function(dst_places.data(), src_places.data(), dst_places.size());
for (size_t i = 0; i < params.aggregates_size; ++i)
{
if (!is_aggregate_function_compiled[i])
aggregate_functions[i]->mergeAndDestroyBatch(
dst_places.data(), src_places.data(), dst_places.size(), offsets_of_aggregate_states[i], arena);
}
return;
}
#endif
for (size_t i = 0; i < params.aggregates_size; ++i)
{
aggregate_functions[i]->mergeAndDestroyBatch(
dst_places.data(), src_places.data(), dst_places.size(), offsets_of_aggregate_states[i], arena);
}
}

View File

@ -357,27 +357,60 @@ static void compileMergeAggregatesStates(llvm::Module & module, const std::vecto
llvm::IRBuilder<> b(module.getContext());
auto * aggregate_data_place_type = b.getInt8Ty()->getPointerTo();
auto * merge_aggregates_states_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { aggregate_data_place_type, aggregate_data_place_type }, false);
auto * merge_aggregates_states_func = llvm::Function::Create(merge_aggregates_states_func_declaration, llvm::Function::ExternalLinkage, name, module);
auto * aggregate_data_places_type = aggregate_data_place_type->getPointerTo();
auto * size_type = b.getInt64Ty();
auto * merge_aggregates_states_func_declaration
= llvm::FunctionType::get(b.getVoidTy(), {aggregate_data_places_type, aggregate_data_places_type, size_type}, false);
auto * merge_aggregates_states_func
= llvm::Function::Create(merge_aggregates_states_func_declaration, llvm::Function::ExternalLinkage, name, module);
auto * arguments = merge_aggregates_states_func->args().begin();
llvm::Value * aggregate_data_place_dst_arg = arguments++;
llvm::Value * aggregate_data_place_src_arg = arguments++;
llvm::Value * aggregate_data_places_dst_arg = arguments++;
llvm::Value * aggregate_data_places_src_arg = arguments++;
llvm::Value * aggregate_places_size_arg = arguments++;
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", merge_aggregates_states_func);
b.SetInsertPoint(entry);
/// Initialize loop
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", merge_aggregates_states_func);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", merge_aggregates_states_func);
b.CreateCondBr(b.CreateICmpEQ(aggregate_places_size_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
b.SetInsertPoint(loop);
/// Loop
auto * counter_phi = b.CreatePHI(size_type, 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
for (const auto & function_to_compile : functions)
{
auto * aggregate_data_place_dst = b.CreateLoad(aggregate_data_place_type,
b.CreateInBoundsGEP(aggregate_data_place_type->getPointerTo(), aggregate_data_places_dst_arg, counter_phi));
auto * aggregate_data_place_src = b.CreateLoad(aggregate_data_place_type,
b.CreateInBoundsGEP(aggregate_data_place_type->getPointerTo(), aggregate_data_places_src_arg, counter_phi));
size_t aggregate_function_offset = function_to_compile.aggregate_data_offset;
auto * aggregate_data_place_merge_dst_with_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_place_dst_arg, aggregate_function_offset);
auto * aggregate_data_place_merge_src_with_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_place_src_arg, aggregate_function_offset);
auto * aggregate_data_place_merge_dst_with_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_place_dst, aggregate_function_offset);
auto * aggregate_data_place_merge_src_with_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_place_src, aggregate_function_offset);
const auto * aggregate_function_ptr = function_to_compile.function;
aggregate_function_ptr->compileMerge(b, aggregate_data_place_merge_dst_with_offset, aggregate_data_place_merge_src_with_offset);
}
/// End of loop
auto * current_block = b.GetInsertBlock();
auto * incremeted_counter = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
counter_phi->addIncoming(incremeted_counter, current_block);
b.CreateCondBr(b.CreateICmpEQ(incremeted_counter, aggregate_places_size_arg), end, loop);
b.SetInsertPoint(end);
b.CreateRetVoid();
}

View File

@ -56,7 +56,7 @@ struct AggregateFunctionWithOffset
using JITCreateAggregateStatesFunction = void (*)(AggregateDataPtr);
using JITAddIntoAggregateStatesFunction = void (*)(ColumnDataRowsOffset, ColumnDataRowsOffset, ColumnData *, AggregateDataPtr *);
using JITAddIntoAggregateStatesFunctionSinglePlace = void (*)(ColumnDataRowsOffset, ColumnDataRowsOffset, ColumnData *, AggregateDataPtr);
using JITMergeAggregateStatesFunction = void (*)(AggregateDataPtr, AggregateDataPtr);
using JITMergeAggregateStatesFunction = void (*)(AggregateDataPtr *, AggregateDataPtr *, size_t);
using JITInsertAggregateStatesIntoColumnsFunction = void (*)(ColumnDataRowsOffset, ColumnDataRowsOffset, ColumnData *, AggregateDataPtr *);
struct CompiledAggregateFunctions