Compile aggregate functions without key

This commit is contained in:
Maksim Kita 2021-07-27 19:50:57 +03:00
parent 4862e9f80d
commit 3a6b37691a
7 changed files with 463 additions and 5 deletions

View File

@ -781,15 +781,48 @@ void NO_INLINE Aggregator::executeImplBatch(
}
template <bool use_compiled_functions>
void NO_INLINE Aggregator::executeWithoutKeyImpl(
AggregatedDataWithoutKey & res,
size_t rows,
AggregateFunctionInstruction * aggregate_instructions,
Arena * arena)
{
/// Adding values
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
#if USE_EMBEDDED_COMPILER
if constexpr (use_compiled_functions)
{
std::vector<ColumnData> columns_data;
for (size_t i = 0; i < aggregate_functions.size(); ++i)
{
if (!is_aggregate_function_compiled[i])
continue;
AggregateFunctionInstruction * inst = aggregate_instructions + i;
size_t arguments_size = inst->that->getArgumentTypes().size();
for (size_t argument_index = 0; argument_index < arguments_size; ++argument_index)
{
columns_data.emplace_back(getColumnData(inst->batch_arguments[argument_index]));
}
}
auto add_into_aggregate_states_function_single_place = compiled_aggregate_functions_holder->compiled_aggregate_functions.add_into_aggregate_states_function_single_place;
add_into_aggregate_states_function_single_place(rows, columns_data.data(), res);
}
#endif
/// Adding values
for (size_t i = 0; i < aggregate_functions.size(); ++i)
{
AggregateFunctionInstruction * inst = aggregate_instructions + i;
#if USE_EMBEDDED_COMPILER
if constexpr (use_compiled_functions)
if (is_aggregate_function_compiled[i])
continue;
#endif
if (inst->offsets)
inst->batch_that->addBatchSinglePlace(
inst->offsets[static_cast<ssize_t>(rows - 1)], res + inst->state_offset, inst->batch_arguments, arena);
@ -930,7 +963,16 @@ bool Aggregator::executeOnBlock(Columns columns, UInt64 num_rows, AggregatedData
/// For the case when there are no keys (all aggregate into one row).
if (result.type == AggregatedDataVariants::Type::without_key)
{
executeWithoutKeyImpl(result.without_key, num_rows, aggregate_functions_instructions.data(), result.aggregates_pool);
#if USE_EMBEDDED_COMPILER
if (compiled_aggregate_functions_holder)
{
executeWithoutKeyImpl<true>(result.without_key, num_rows, aggregate_functions_instructions.data(), result.aggregates_pool);
}
else
#endif
{
executeWithoutKeyImpl<false>(result.without_key, num_rows, aggregate_functions_instructions.data(), result.aggregates_pool);
}
}
else
{

View File

@ -1121,7 +1121,7 @@ private:
AggregateDataPtr overflow_row) const;
/// Specialization for a particular value no_more_keys.
template <bool no_more_keys, bool use_compiled_expressions, typename Method>
template <bool no_more_keys, bool use_compiled_functions, typename Method>
void executeImplBatch(
Method & method,
typename Method::State & state,
@ -1131,7 +1131,8 @@ private:
AggregateDataPtr overflow_row) const;
/// For case when there are no keys (all aggregate into one row).
static void executeWithoutKeyImpl(
template <bool use_compiled_functions>
void executeWithoutKeyImpl(
AggregatedDataWithoutKey & res,
size_t rows,
AggregateFunctionInstruction * aggregate_instructions,

View File

@ -436,6 +436,133 @@ static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const
b.CreateRetVoid();
}
static void compileAddIntoAggregateStatesFunctionsSinglePlace(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto & context = module.getContext();
llvm::IRBuilder<> b(context);
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * places_type = b.getInt8Ty()->getPointerTo();
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), places_type }, false);
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
auto * arguments = aggregate_loop_func_definition->args().begin();
llvm::Value * rows_count_arg = arguments++;
llvm::Value * columns_arg = arguments++;
llvm::Value * place_arg = arguments++;
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns;
size_t previous_columns_size = 0;
for (const auto & function : functions)
{
auto argument_types = function.function->getArgumentTypes();
ColumnDataPlaceholder data_placeholder;
size_t function_arguments_size = argument_types.size();
for (size_t column_argument_index = 0; column_argument_index < function_arguments_size; ++column_argument_index)
{
const auto & argument_type = argument_types[column_argument_index];
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_64(column_data_type, columns_arg, previous_columns_size + column_argument_index));
data_placeholder.data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
data_placeholder.null_init = argument_type->isNullable() ? b.CreateExtractValue(data, {1}) : nullptr;
columns.emplace_back(data_placeholder);
}
previous_columns_size += function_arguments_size;
}
/// Initialize loop
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
if (col.null_init)
{
col.null = b.CreatePHI(col.null_init->getType(), 2);
col.null->addIncoming(col.null_init, entry);
}
}
previous_columns_size = 0;
for (const auto & function : functions)
{
size_t aggregate_function_offset = function.aggregate_data_offset;
const auto * aggregate_function_ptr = function.function;
auto arguments_types = function.function->getArgumentTypes();
std::vector<llvm::Value *> arguments_values;
size_t function_arguments_size = arguments_types.size();
arguments_values.resize(function_arguments_size);
for (size_t column_argument_index = 0; column_argument_index < function_arguments_size; ++column_argument_index)
{
auto * column_argument_data = columns[previous_columns_size + column_argument_index].data;
auto * column_argument_null_data = columns[previous_columns_size + column_argument_index].null;
auto & argument_type = arguments_types[column_argument_index];
auto * value = b.CreateLoad(toNativeType(b, removeNullable(argument_type)), column_argument_data);
if (!argument_type->isNullable())
{
arguments_values[column_argument_index] = value;
continue;
}
auto * is_null = b.CreateICmpNE(b.CreateLoad(b.getInt8Ty(), column_argument_null_data), b.getInt8(0));
auto * nullable_unitilized = llvm::Constant::getNullValue(toNativeType(b, argument_type));
auto * nullable_value = b.CreateInsertValue(b.CreateInsertValue(nullable_unitilized, value, {0}), is_null, {1});
arguments_values[column_argument_index] = nullable_value;
}
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_64(nullptr, place_arg, aggregate_function_offset);
aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, arguments_types, arguments_values);
previous_columns_size += function_arguments_size;
}
/// End of loop
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_64(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_64(nullptr, col.null, 1), cur_block);
}
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
counter_phi->addIncoming(value, cur_block);
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
b.SetInsertPoint(end);
b.CreateRetVoid();
}
static void compileMergeAggregatesStates(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto & context = module.getContext();
@ -569,6 +696,7 @@ CompiledAggregateFunctions compileAggregateFunctions(CHJIT & jit, const std::vec
std::string create_aggregate_states_functions_name = functions_dump_name + "_create";
std::string add_aggregate_states_functions_name = functions_dump_name + "_add";
std::string add_aggregate_states_functions_name_single_place = functions_dump_name + "_add_single_place";
std::string merge_aggregate_states_functions_name = functions_dump_name + "_merge";
std::string insert_aggregate_states_functions_name = functions_dump_name + "_insert";
@ -576,17 +704,20 @@ CompiledAggregateFunctions compileAggregateFunctions(CHJIT & jit, const std::vec
{
compileCreateAggregateStatesFunctions(module, functions, create_aggregate_states_functions_name);
compileAddIntoAggregateStatesFunctions(module, functions, add_aggregate_states_functions_name);
compileAddIntoAggregateStatesFunctionsSinglePlace(module, functions, add_aggregate_states_functions_name_single_place);
compileMergeAggregatesStates(module, functions, merge_aggregate_states_functions_name);
compileInsertAggregatesIntoResultColumns(module, functions, insert_aggregate_states_functions_name);
});
auto create_aggregate_states_function = reinterpret_cast<JITCreateAggregateStatesFunction>(compiled_module.function_name_to_symbol[create_aggregate_states_functions_name]);
auto add_into_aggregate_states_function = reinterpret_cast<JITAddIntoAggregateStatesFunction>(compiled_module.function_name_to_symbol[add_aggregate_states_functions_name]);
auto add_into_aggregate_states_function_single_place = reinterpret_cast<JITAddIntoAggregateStatesFunctionSinglePlace>(compiled_module.function_name_to_symbol[add_aggregate_states_functions_name_single_place]);
auto merge_aggregate_states_function = reinterpret_cast<JITMergeAggregateStatesFunction>(compiled_module.function_name_to_symbol[merge_aggregate_states_functions_name]);
auto insert_aggregate_states_function = reinterpret_cast<JITInsertAggregateStatesIntoColumnsFunction>(compiled_module.function_name_to_symbol[insert_aggregate_states_functions_name]);
assert(create_aggregate_states_function);
assert(add_into_aggregate_states_function);
assert(add_into_aggregate_states_function_single_place);
assert(merge_aggregate_states_function);
assert(insert_aggregate_states_function);
@ -598,6 +729,7 @@ CompiledAggregateFunctions compileAggregateFunctions(CHJIT & jit, const std::vec
{
.create_aggregate_states_function = create_aggregate_states_function,
.add_into_aggregate_states_function = add_into_aggregate_states_function,
.add_into_aggregate_states_function_single_place = add_into_aggregate_states_function_single_place,
.merge_aggregate_states_function = merge_aggregate_states_function,
.insert_aggregates_into_columns_function = insert_aggregate_states_function,

View File

@ -54,6 +54,7 @@ struct AggregateFunctionWithOffset
using JITCreateAggregateStatesFunction = void (*)(AggregateDataPtr);
using JITAddIntoAggregateStatesFunction = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr *);
using JITAddIntoAggregateStatesFunctionSinglePlace = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr);
using JITMergeAggregateStatesFunction = void (*)(AggregateDataPtr, AggregateDataPtr);
using JITInsertAggregateStatesIntoColumnsFunction = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr *);
@ -61,6 +62,8 @@ struct CompiledAggregateFunctions
{
JITCreateAggregateStatesFunction create_aggregate_states_function;
JITAddIntoAggregateStatesFunction add_into_aggregate_states_function;
JITAddIntoAggregateStatesFunctionSinglePlace add_into_aggregate_states_function_single_place;
JITMergeAggregateStatesFunction merge_aggregate_states_function;
JITInsertAggregateStatesIntoColumnsFunction insert_aggregates_into_columns_function;
@ -75,6 +78,7 @@ struct CompiledAggregateFunctions
*
* JITCreateAggregateStatesFunction will initialize aggregate data ptr with initial aggregate states values.
* JITAddIntoAggregateStatesFunction will update aggregate states for aggregate functions with specified ColumnData.
* JITAddIntoAggregateStatesFunctionSinglePlace will update single aggregate state for aggregate functions with specified ColumnData.
* JITMergeAggregateStatesFunction will merge aggregate states for aggregate functions.
* JITInsertAggregateStatesIntoColumnsFunction will insert aggregate states for aggregate functions into result columns.
*/

View File

@ -150,6 +150,44 @@
FORMAT Null
</query>
<query>
SELECT
{function}(value_1),
{function}(value_2),
{function}(value_3)
FROM {table}
FORMAT Null
</query>
<query>
SELECT
{function}(value_1),
{function}(value_2),
sum(toUInt256(value_3)),
{function}(value_3)
FROM {table}
FORMAT Null
</query>
<query>
SELECT
{function}If(value_1, predicate),
{function}If(value_2, predicate),
{function}If(value_3, predicate)
FROM {table}
FORMAT Null
</query>
<query>
SELECT
{function}If(value_1, predicate),
{function}If(value_2, predicate),
sumIf(toUInt256(value_3), predicate),
{function}If(value_3, predicate)
FROM {table}
FORMAT Null
</query>
<query>
SELECT
{function}(value_1),
@ -200,6 +238,51 @@
FORMAT Null
</query>
<query>
SELECT
{function}(value_1),
{function}(value_2),
{function}(value_3),
{function}(value_4),
{function}(value_5)
FROM {table}
FORMAT Null
</query>
<query>
SELECT
{function}(value_1),
{function}(value_2),
sum(toUInt256(value_3)),
{function}(value_3),
{function}(value_4),
{function}(value_5)
FROM {table}
FORMAT Null
</query>
<query>
SELECT
{function}If(value_1, predicate),
{function}If(value_2, predicate),
{function}If(value_3, predicate),
{function}If(value_4, predicate),
{function}If(value_5, predicate)
FROM {table}
FORMAT Null
</query>
<query>
SELECT
{function}If(value_1, predicate),
{function}If(value_2, predicate),
sumIf(toUInt256(value_3), predicate),
{function}If(value_3, predicate),
{function}If(value_4, predicate),
{function}If(value_5, predicate)
FROM {table}
FORMAT Null
</query>
<query>
SELECT
@ -247,6 +330,48 @@
FORMAT Null
</query>
<query>
SELECT
{function}(WatchID),
{function}(CounterID),
{function}(ClientIP)
FROM hits_100m_single
FORMAT Null
</query>
<query>
SELECT
{function}(WatchID),
{function}(CounterID),
sum(toUInt256(ClientIP)),
{function}(ClientIP)
FROM hits_100m_single
FORMAT Null
</query>
<query>
SELECT
{function}(WatchID),
{function}(CounterID),
{function}(ClientIP),
{function}(IPNetworkID),
{function}(SearchEngineID)
FROM hits_100m_single
FORMAT Null
</query>
<query>
SELECT
{function}(WatchID),
{function}(CounterID),
sum(toUInt256(ClientIP)),
{function}(ClientIP),
{function}(IPNetworkID),
{function}(SearchEngineID)
FROM hits_100m_single
FORMAT Null
</query>
<query>
WITH (WatchID % 2 == 0) AS predicate
SELECT
@ -297,5 +422,51 @@
FORMAT Null
</query>
<query>
WITH (WatchID % 2 == 0) AS predicate
SELECT
{function}If(WatchID, predicate),
{function}If(CounterID, predicate),
{function}If(ClientIP, predicate)
FROM hits_100m_single
FORMAT Null
</query>
<query>
WITH (WatchID % 2 == 0) AS predicate
SELECT
{function}If(WatchID, predicate),
{function}If(CounterID, predicate),
sumIf(toUInt256(ClientIP), predicate),
{function}If(ClientIP, predicate)
FROM hits_100m_single
FORMAT Null
</query>
<query>
WITH (WatchID % 2 == 0) AS predicate
SELECT
{function}If(WatchID, predicate),
{function}If(CounterID, predicate),
{function}If(ClientIP, predicate),
{function}If(IPNetworkID, predicate),
{function}If(SearchEngineID, predicate)
FROM hits_100m_single
FORMAT Null
</query>
<query>
WITH (WatchID % 2 == 0) AS predicate
SELECT
{function}If(WatchID, predicate),
{function}If(CounterID, predicate),
sumIf(toUInt256(ClientIP), predicate),
{function}If(ClientIP, predicate),
{function}If(IPNetworkID, predicate),
{function}If(SearchEngineID, predicate)
FROM hits_100m_single
FORMAT Null
</query>
<drop_query>DROP TABLE IF EXISTS {table}</drop_query>
</test>

View File

@ -62,6 +62,12 @@ Simple functions if combinator
11482817 4611990575414646848 9223302669582414438 9828522700609834800 378121905921203.2 34845264.2080656 25993 9223372036854775806 4611686018427387904 4689180182672571856
63469 4612175339998036670 9222961628400798084 17239621485933250238 663164390134376.5 7825349797.6059 25996 9223372036854775806 4611686018427387904 2067736879306995526
29103473 4611744585914335132 9223035551850347954 12590190375872647672 525927999326314.7 26049107.15514301 23939 9223372036854775806 4611686018427387904 8318055464870862444
Simple functions without key
4611686725751467379 9223371678237104442 3626326766789368100 408650940859.2896 104735.01095549858 8873898 9223372036854775807 4611686018427387904 3818489297630359920
Simple functions with non compilable function without key
4611686725751467379 9223371678237104442 3626326766789368100 61384643584599682996279588 408650940859.2896 104735.01095549858 8873898 9223372036854775807 4611686018427387904 3818489297630359920
Simple functions if combinator without key
4611687533683519016 9223371678237104442 4124667747700004330 930178817930.5122 321189.2280948817 4434274 9223372036854775806 4611686018427387904 2265422677606390266
Aggregation without JIT compilation
Simple functions
1704509 4611700827100483880 9223360787015464643 10441337359398154812 19954243669348.844 9648741.579254271 523264 9223372036854775807 4611686018427387904 4544239379628300646
@ -126,3 +132,9 @@ Simple functions if combinator
11482817 4611990575414646848 9223302669582414438 9828522700609834800 378121905921203.2 34845264.2080656 25993 9223372036854775806 4611686018427387904 4689180182672571856
63469 4612175339998036670 9222961628400798084 17239621485933250238 663164390134376.5 7825349797.6059 25996 9223372036854775806 4611686018427387904 2067736879306995526
29103473 4611744585914335132 9223035551850347954 12590190375872647672 525927999326314.7 26049107.15514301 23939 9223372036854775806 4611686018427387904 8318055464870862444
Simple functions without key
4611686725751467379 9223371678237104442 3626326766789368100 408650940859.2896 104735.01095549858 8873898 9223372036854775807 4611686018427387904 3818489297630359920
Simple functions with non compilable function without key
4611686725751467379 9223371678237104442 3626326766789368100 61384643584599682996279588 408650940859.2896 104735.01095549858 8873898 9223372036854775807 4611686018427387904 3818489297630359920
Simple functions if combinator without key
4611687533683519016 9223371678237104442 4124667747700004330 930178817930.5122 321189.2280948817 4434274 9223372036854775806 4611686018427387904 2265422677606390266

View File

@ -53,6 +53,54 @@ SELECT
FROM test.hits
GROUP BY CounterID ORDER BY count() DESC LIMIT 20;
SELECT 'Simple functions without key';
SELECT
min(WatchID) AS min_watch_id,
max(WatchID),
sum(WatchID),
avg(WatchID),
avgWeighted(WatchID, CounterID),
count(WatchID),
groupBitOr(WatchID),
groupBitAnd(WatchID),
groupBitXor(WatchID)
FROM test.hits
ORDER BY min_watch_id DESC LIMIT 20;
SELECT 'Simple functions with non compilable function without key';
SELECT
min(WatchID) AS min_watch_id,
max(WatchID),
sum(WatchID),
sum(toUInt128(WatchID)),
avg(WatchID),
avgWeighted(WatchID, CounterID),
count(WatchID),
groupBitOr(WatchID),
groupBitAnd(WatchID),
groupBitXor(WatchID)
FROM test.hits
ORDER BY min_watch_id DESC LIMIT 20;
SELECT 'Simple functions if combinator without key';
WITH (WatchID % 2 == 0) AS predicate
SELECT
minIf(WatchID, predicate) as min_watch_id,
maxIf(WatchID, predicate),
sumIf(WatchID, predicate),
avgIf(WatchID, predicate),
avgWeightedIf(WatchID, CounterID, predicate),
countIf(WatchID, predicate),
groupBitOrIf(WatchID, predicate),
groupBitAndIf(WatchID, predicate),
groupBitXorIf(WatchID, predicate)
FROM test.hits
ORDER BY min_watch_id
DESC LIMIT 20;
SET compile_aggregate_expressions = 0;
SELECT 'Aggregation without JIT compilation';
@ -105,3 +153,51 @@ SELECT
groupBitXorIf(WatchID, predicate)
FROM test.hits
GROUP BY CounterID ORDER BY count() DESC LIMIT 20;
SELECT 'Simple functions without key';
SELECT
min(WatchID) AS min_watch_id,
max(WatchID),
sum(WatchID),
avg(WatchID),
avgWeighted(WatchID, CounterID),
count(WatchID),
groupBitOr(WatchID),
groupBitAnd(WatchID),
groupBitXor(WatchID)
FROM test.hits
ORDER BY min_watch_id DESC LIMIT 20;
SELECT 'Simple functions with non compilable function without key';
SELECT
min(WatchID) AS min_watch_id,
max(WatchID),
sum(WatchID),
sum(toUInt128(WatchID)),
avg(WatchID),
avgWeighted(WatchID, CounterID),
count(WatchID),
groupBitOr(WatchID),
groupBitAnd(WatchID),
groupBitXor(WatchID)
FROM test.hits
ORDER BY min_watch_id DESC LIMIT 20;
SELECT 'Simple functions if combinator without key';
WITH (WatchID % 2 == 0) AS predicate
SELECT
minIf(WatchID, predicate) as min_watch_id,
maxIf(WatchID, predicate),
sumIf(WatchID, predicate),
avgIf(WatchID, predicate),
avgWeightedIf(WatchID, CounterID, predicate),
countIf(WatchID, predicate),
groupBitOrIf(WatchID, predicate),
groupBitAndIf(WatchID, predicate),
groupBitXorIf(WatchID, predicate)
FROM test.hits
ORDER BY min_watch_id
DESC LIMIT 20;