Compile aggregate functions

This commit is contained in:
Maksim Kita 2021-05-31 11:05:40 +03:00
parent aa8d4aea54
commit 3fe559b31f
13 changed files with 499 additions and 18 deletions

View File

@ -0,0 +1,39 @@
#include <functional>
template <typename Functor>
class FunctorToStaticMethodAdaptor : public FunctorToStaticMethodAdaptor<decltype(&Functor::operator())>
{
public:
};
template <typename R, typename C, typename ...Args>
class FunctorToStaticMethodAdaptor<R (C::*)(Args...) const>
{
public:
static R call(C * ptr, Args... arguments)
{
return std::invoke(&C::operator(), ptr, arguments...);
}
static R unsafeCall(char * ptr, Args... arguments)
{
C * ptr_typed = reinterpret_cast<C*>(ptr);
return std::invoke(&C::operator(), ptr_typed, arguments...);
}
};
template <typename R, typename C, typename ...Args>
class FunctorToStaticMethodAdaptor<R (C::*)(Args...)>
{
public:
static R call(C * ptr, Args... arguments)
{
return std::invoke(&C::operator(), ptr, arguments...);
}
static R unsafeCall(char * ptr, Args... arguments)
{
C * ptr_typed = static_cast<C*>(ptr);
return std::invoke(&C::operator(), ptr_typed, arguments...);
}
};

View File

@ -9,6 +9,14 @@
#include <AggregateFunctions/IAggregateFunction.h>
#include <Core/DecimalFunctions.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
#endif
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
# include <DataTypes/Native.h>
#endif
namespace DB
{
@ -157,6 +165,37 @@ public:
++this->data(place).denominator;
}
#if USE_EMBEDDED_COMPILER
virtual bool isCompilable() const override
{
using AverageFieldType = AvgFieldType<T>;
return std::is_same_v<AverageFieldType, UInt64> || std::is_same_v<AverageFieldType, Int64>;
}
virtual void compile(llvm::IRBuilderBase & builder, llvm::Value * aggregate_function_place, const DataTypePtr & value_type, llvm::Value * value) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
llvm::Type * numerator_type = b.getInt64Ty();
llvm::Type * denominator_type = b.getInt64Ty();
auto * numerator_value_ptr = b.CreatePointerCast(aggregate_function_place, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_value_ptr);
auto * value_cast_to_result = nativeCast(b, value_type, value, numerator_type);
auto * sum_result_value = numerator_value->getType()->isIntegerTy() ? b.CreateAdd(numerator_value, value_cast_to_result) : b.CreateFAdd(numerator_value, value_cast_to_result);
b.CreateStore(sum_result_value, numerator_value_ptr);
auto * denominator_place_ptr_untyped = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_function_place, 8);
auto * denominator_place_ptr = b.CreatePointerCast(denominator_place_ptr_untyped, denominator_type->getPointerTo());
auto * denominator_value = b.CreateLoad(denominator_place_ptr, numerator_value_ptr);
auto * increate_denominator_value = b.CreateAdd(denominator_value, llvm::ConstantInt::get(denominator_type, 1));
b.CreateStore(increate_denominator_value, denominator_place_ptr);
}
#endif
String getName() const final { return "avg"; }
};
}

View File

@ -12,6 +12,14 @@
#include <AggregateFunctions/IAggregateFunction.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
#endif
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
# include <DataTypes/Native.h>
#endif
namespace DB
{
@ -385,6 +393,24 @@ public:
column.getData().push_back(this->data(place).get());
}
#if USE_EMBEDDED_COMPILER
virtual bool isCompilable() const override { return Type == AggregateFunctionTypeSum; }
virtual void compile(llvm::IRBuilderBase & builder, llvm::Value * aggregate_function_place, const DataTypePtr & value_type, llvm::Value * value) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_native_type = toNativeType(b, removeNullable(getReturnType()));
auto * sum_value_ptr = b.CreatePointerCast(aggregate_function_place, return_native_type->getPointerTo());
auto * sum_value = b.CreateLoad(return_native_type, sum_value_ptr);
auto * value_cast_to_result = nativeCast(b, value_type, value, return_native_type);
auto * sum_result_value = sum_value->getType()->isIntegerTy() ? b.CreateAdd(sum_value, value_cast_to_result) : b.CreateFAdd(sum_value, value_cast_to_result);
b.CreateStore(sum_result_value, sum_value_ptr);
}
#endif
private:
UInt32 scale;
};

View File

@ -9,11 +9,21 @@
#include <Common/Exception.h>
#include <common/types.h>
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#include <cstddef>
#include <memory>
#include <vector>
#include <type_traits>
namespace llvm
{
class LLVMContext;
class Value;
class IRBuilderBase;
}
namespace DB
{
@ -241,6 +251,17 @@ public:
// of true window functions, so this hack-ish interface suffices.
virtual bool isOnlyWindowFunction() const { return false; }
#if USE_EMBEDDED_COMPILER
virtual bool isCompilable() const { return false; }
virtual void compile(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_function_place*/, const DataTypePtr & /*value_type*/, llvm::Value * /*value*/) const
{
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}
#endif
protected:
DataTypes argument_types;
Array parameters;

View File

@ -7,6 +7,11 @@
#pragma clang diagnostic ignored "-Wreserved-id-macro"
#endif
#undef __msan_unpoison
#undef __msan_test_shadow
#undef __msan_print_shadow
#undef __msan_unpoison_string
#define __msan_unpoison(X, Y)
#define __msan_test_shadow(X, Y) (false)
#define __msan_print_shadow(X, Y)

View File

@ -106,6 +106,8 @@ class IColumn;
M(Bool, allow_suspicious_low_cardinality_types, false, "In CREATE TABLE statement allows specifying LowCardinality modifier for types of small fixed size (8 or less). Enabling this may increase merge times and memory consumption.", 0) \
M(Bool, compile_expressions, true, "Compile some scalar functions and operators to native code.", 0) \
M(UInt64, min_count_to_compile_expression, 3, "The number of identical expressions before they are JIT-compiled", 0) \
M(Bool, compile_aggregate_expressions, true, "Compile aggregate functions to native code.", 0) \
M(UInt64, min_count_to_compile_aggregate_expression, 0, "The number of identical aggreagte expressions before they are JIT-compiled", 0) \
M(UInt64, group_by_two_level_threshold, 100000, "From what number of keys, a two-level aggregation starts. 0 - the threshold is not set.", 0) \
M(UInt64, group_by_two_level_threshold_bytes, 50000000, "From what size of the aggregation state in bytes, a two-level aggregation begins to be used. 0 - the threshold is not set. Two-level aggregation is used when at least one of the thresholds is triggered.", 0) \
M(Bool, distributed_aggregation_memory_efficient, true, "Is the memory-saving mode of distributed aggregation enabled.", 0) \

View File

@ -33,7 +33,8 @@ TTLAggregationAlgorithm::TTLAggregationAlgorithm(
Aggregator::Params params(header, keys, aggregates,
false, settings.max_rows_to_group_by, settings.group_by_overflow_mode, 0, 0,
settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set,
storage_.getContext()->getTemporaryVolume(), settings.max_threads, settings.min_free_disk_space_for_temporary_data);
storage_.getContext()->getTemporaryVolume(), settings.max_threads, settings.min_free_disk_space_for_temporary_data,
settings.compile_aggregate_expressions, settings.min_count_to_compile_aggregate_expression);
aggregator = std::make_unique<Aggregator>(params);
}

View File

@ -1,5 +1,7 @@
#include <future>
#include <Poco/Util/Application.h>
#include <common/FunctorToStaticMethodAdaptor.h>
#include <Common/Stopwatch.h>
#include <Common/setThreadName.h>
#include <Common/formatReadable.h>
@ -21,6 +23,7 @@
#include <AggregateFunctions/AggregateFunctionArray.h>
#include <AggregateFunctions/AggregateFunctionState.h>
#include <IO/Operators.h>
#include <Interpreters/JIT/compileFunction.h>
namespace ProfileEvents
@ -477,6 +480,11 @@ void NO_INLINE Aggregator::executeImpl(
executeImplBatch<true>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
}
static CHJIT & getJITInstance()
{
static CHJIT jit;
return jit;
}
template <bool no_more_keys, typename Method>
void NO_INLINE Aggregator::executeImplBatch(
@ -537,16 +545,13 @@ void NO_INLINE Aggregator::executeImplBatch(
/// Generic case.
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[rows]);
/// For all rows.
for (size_t i = 0; i < rows; ++i)
auto get_aggregate_data = [&](size_t row) -> AggregateDataPtr
{
AggregateDataPtr aggregate_data = nullptr;
AggregateDataPtr aggregate_data;
if constexpr (!no_more_keys)
{
auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool);
auto emplace_result = state.emplaceKey(method.data, row, *aggregates_pool);
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
if (emplace_result.isInserted())
@ -567,23 +572,148 @@ void NO_INLINE Aggregator::executeImplBatch(
else
{
/// Add only if the key already exists.
auto find_result = state.findKey(method.data, i, *aggregates_pool);
auto find_result = state.findKey(method.data, row, *aggregates_pool);
if (find_result.isFound())
aggregate_data = find_result.getMapped();
else
aggregate_data = overflow_row;
}
places[i] = aggregate_data;
}
// std::cerr << "Row " << row << " returned place " << static_cast<void *>(aggregate_data) << std::endl;
return aggregate_data;
};
#if USE_EMBEDDED_COMPILER
std::vector<ColumnData> columns_data;
std::vector<AggregateFunctionToCompile> functions_to_compile;
size_t aggregate_instructions_size = 0;
/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
{
if (inst->offsets)
inst->batch_that->addBatchArray(rows, places.get(), inst->state_offset, inst->batch_arguments, inst->offsets, aggregates_pool);
const auto * function = inst->that;
if (function && function->isCompilable())
{
AggregateFunctionToCompile function_to_compile
{
.function = inst->that,
.aggregate_data_offset = inst->state_offset
};
columns_data.emplace_back(getColumnData(inst->batch_arguments[0]));
functions_to_compile.emplace_back(std::move(function_to_compile));
}
++aggregate_instructions_size;
}
if (params.compile_aggregate_expressions && functions_to_compile.size() == aggregate_instructions_size)
{
std::string functions_dump;
for (const auto & func : functions_to_compile)
{
const auto * function = func.function;
std::string function_dump;
auto return_type_name = function->getReturnType()->getName();
function_dump += return_type_name;
function_dump += ' ';
function_dump += function->getName();
function_dump += '(';
const auto & argument_types = function->getArgumentTypes();
for (const auto & argument_type : argument_types)
{
function_dump += argument_type->getName();
function_dump += ',';
}
if (!argument_types.empty())
function_dump.pop_back();
function_dump += ')';
functions_dump += function_dump;
functions_dump += ' ';
}
static std::unordered_map<std::string, CHJIT::CompiledModuleInfo> aggregation_functions_dump_to_compiled_module_info;
CHJIT::CompiledModuleInfo compiled_module;
auto it = aggregation_functions_dump_to_compiled_module_info.find(functions_dump);
if (it != aggregation_functions_dump_to_compiled_module_info.end())
{
compiled_module = it->second;
LOG_TRACE(log, "Get compiled aggregate functions {} from cache", functions_dump);
}
else
inst->batch_that->addBatch(rows, places.get(), inst->state_offset, inst->batch_arguments, aggregates_pool);
{
compiled_module = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_dump);
aggregation_functions_dump_to_compiled_module_info[functions_dump] = compiled_module;
}
LOG_TRACE(log, "Use compiled expression {}", functions_dump);
JITCompiledAggregateFunction aggregate_function = reinterpret_cast<JITCompiledAggregateFunction>(getJITInstance().findCompiledFunction(compiled_module, functions_dump));
GetAggregateDataFunction get_aggregate_data_function = FunctorToStaticMethodAdaptor<decltype(get_aggregate_data)>::unsafeCall;
GetAggregateDataContext get_aggregate_data_context = reinterpret_cast<char *>(&get_aggregate_data);
aggregate_function(rows, columns_data.data(), get_aggregate_data_function, get_aggregate_data_context);
}
else
#endif
{
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[rows]);
/// For all rows.
for (size_t i = 0; i < rows; ++i)
{
AggregateDataPtr aggregate_data;
if constexpr (!no_more_keys)
{
auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool);
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
if (emplace_result.isInserted())
{
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
emplace_result.setMapped(nullptr);
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
createAggregateStates(aggregate_data);
emplace_result.setMapped(aggregate_data);
}
else
aggregate_data = emplace_result.getMapped();
assert(aggregate_data != nullptr);
}
else
{
/// Add only if the key already exists.
auto find_result = state.findKey(method.data, i, *aggregates_pool);
if (find_result.isFound())
aggregate_data = find_result.getMapped();
else
aggregate_data = overflow_row;
}
places[i] = aggregate_data;
}
/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
{
if (inst->offsets)
inst->batch_that->addBatchArray(rows, places.get(), inst->state_offset, inst->batch_arguments, inst->offsets, aggregates_pool);
else
inst->batch_that->addBatch(rows, places.get(), inst->state_offset, inst->batch_arguments, aggregates_pool);
}
}
}

View File

@ -907,6 +907,10 @@ public:
size_t max_threads;
const size_t min_free_disk_space;
bool compile_aggregate_expressions;
size_t min_count_to_compile_aggregate_expression;
Params(
const Block & src_header_,
const ColumnNumbers & keys_, const AggregateDescriptions & aggregates_,
@ -916,6 +920,8 @@ public:
bool empty_result_for_aggregation_by_empty_set_,
VolumePtr tmp_volume_, size_t max_threads_,
size_t min_free_disk_space_,
bool compile_aggregate_expressions_,
size_t min_count_to_compile_aggregate_expression_,
const Block & intermediate_header_ = {})
: src_header(src_header_),
intermediate_header(intermediate_header_),
@ -925,14 +931,16 @@ public:
max_bytes_before_external_group_by(max_bytes_before_external_group_by_),
empty_result_for_aggregation_by_empty_set(empty_result_for_aggregation_by_empty_set_),
tmp_volume(tmp_volume_), max_threads(max_threads_),
min_free_disk_space(min_free_disk_space_)
min_free_disk_space(min_free_disk_space_),
compile_aggregate_expressions(compile_aggregate_expressions_),
min_count_to_compile_aggregate_expression(min_count_to_compile_aggregate_expression_)
{
}
/// Only parameters that matter during merge.
Params(const Block & intermediate_header_,
const ColumnNumbers & keys_, const AggregateDescriptions & aggregates_, bool overflow_row_, size_t max_threads_)
: Params(Block(), keys_, aggregates_, overflow_row_, 0, OverflowMode::THROW, 0, 0, 0, false, nullptr, max_threads_, 0)
: Params(Block(), keys_, aggregates_, overflow_row_, 0, OverflowMode::THROW, 0, 0, 0, false, nullptr, max_threads_, 0, false, 0)
{
intermediate_header = intermediate_header_;
}

View File

@ -2038,7 +2038,9 @@ void InterpreterSelectQuery::executeAggregation(QueryPlan & query_plan, const Ac
settings.empty_result_for_aggregation_by_empty_set,
context->getTemporaryVolume(),
settings.max_threads,
settings.min_free_disk_space_for_temporary_data);
settings.min_free_disk_space_for_temporary_data,
settings.compile_aggregate_expressions,
settings.min_count_to_compile_aggregate_expression);
SortDescription group_by_sort_description;
@ -2140,7 +2142,9 @@ void InterpreterSelectQuery::executeRollupOrCube(QueryPlan & query_plan, Modific
settings.empty_result_for_aggregation_by_empty_set,
context->getTemporaryVolume(),
settings.max_threads,
settings.min_free_disk_space_for_temporary_data);
settings.min_free_disk_space_for_temporary_data,
settings.compile_aggregate_expressions,
settings.min_count_to_compile_aggregate_expression);
auto transform_params = std::make_shared<AggregatingTransformParams>(params, true);

View File

@ -266,6 +266,191 @@ CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBase & fun
return compiled_module_info;
}
CHJIT::CompiledModuleInfo compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name)
{
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
{
auto & context = module.getContext();
llvm::IRBuilder<> b (context);
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * get_place_func_declaration = llvm::FunctionType::get(b.getInt8Ty()->getPointerTo(), { b.getInt8Ty()->getPointerTo(), size_type }, /*isVarArg=*/false);
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), get_place_func_declaration->getPointerTo(), b.getInt8Ty()->getPointerTo() }, false);
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, result_name, module);
auto * arguments = aggregate_loop_func_definition->args().begin();
llvm::Value * rows_count_arg = &*arguments++;
llvm::Value * columns_arg = &*arguments++;
llvm::Value * get_place_function_arg = &*arguments++;
llvm::Value * get_place_function_context_arg = &*arguments++;
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (size_t i = 0; i < functions.size(); ++i)
{
auto argument_type = functions[i].function->getArgumentTypes()[0];
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
}
/// Initialize loop
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
}
auto * aggregation_place = b.CreateCall(get_place_func_declaration, get_place_function_arg, { get_place_function_context_arg, counter_phi });
for (size_t i = 0; i < functions.size(); ++i)
{
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
const auto * aggregate_function_ptr = functions[i].function;
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregation_place, aggregate_function_offset);
auto column_type = functions[i].function->getArgumentTypes()[0];
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
aggregate_function_ptr->compile(b, aggregation_place_with_offset, column_type, column_data);
}
/// End of loop
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
counter_phi->addIncoming(value, loop);
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
b.SetInsertPoint(end);
b.CreateRetVoid();
llvm::errs() << "Module before optimizations \n";
module.print(llvm::errs(), nullptr);
});
return compiled_module_info;
}
CHJIT::CompiledModuleInfo compileAggregateFunctonsV2(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name)
{
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
{
auto & context = module.getContext();
llvm::IRBuilder<> b (context);
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo()->getPointerTo();
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), aggregate_data_places_type }, false);
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, result_name, module);
auto * arguments = aggregate_loop_func_definition->args().begin();
llvm::Value * rows_count_arg = &*arguments++;
llvm::Value * columns_arg = &*arguments++;
llvm::Value * aggregate_data_places_arg = &*arguments++;
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (size_t i = 0; i < functions.size(); ++i)
{
auto argument_type = functions[i].function->getArgumentTypes()[0];
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
}
/// Initialize loop
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
auto * aggregate_data_place_phi = b.CreatePHI(aggregate_data_places_type, 2);
aggregate_data_place_phi->addIncoming(aggregate_data_places_arg, entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
}
for (size_t i = 0; i < functions.size(); ++i)
{
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
const auto * aggregate_function_ptr = functions[i].function;
auto * aggregate_data_place = b.CreateLoad(b.getInt8Ty()->getPointerTo(), aggregate_data_place_phi);
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place, aggregate_function_offset);
auto column_type = functions[i].function->getArgumentTypes()[0];
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
aggregate_function_ptr->compile(b, aggregation_place_with_offset, column_type, column_data);
}
/// End of loop
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1), "", true, true);
counter_phi->addIncoming(value, loop);
aggregate_data_place_phi->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_phi, 1), loop);
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
b.SetInsertPoint(end);
b.CreateRetVoid();
llvm::errs() << "Module before optimizations \n";
module.print(llvm::errs(), nullptr);
});
return compiled_module_info;
}
}
#endif

View File

@ -7,6 +7,7 @@
#if USE_EMBEDDED_COMPILER
#include <Functions/IFunction.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <Interpreters/JIT/CHJIT.h>
namespace DB
@ -28,6 +29,7 @@ struct ColumnData
ColumnData getColumnData(const IColumn * column);
using ColumnDataRowsSize = size_t;
using JITCompiledFunction = void (*)(ColumnDataRowsSize, ColumnData *);
/** Compile function to native jit code using CHJIT instance.
@ -41,6 +43,21 @@ using JITCompiledFunction = void (*)(ColumnDataRowsSize, ColumnData *);
*/
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBase & function);
using GetAggregateDataContext = char *;
using GetAggregateDataFunction = AggregateDataPtr (*)(GetAggregateDataContext, size_t);
using JITCompiledAggregateFunction = void (*)(ColumnDataRowsSize, ColumnData *, GetAggregateDataFunction, GetAggregateDataContext);
struct AggregateFunctionToCompile
{
const IAggregateFunction * function;
size_t aggregate_data_offset;
};
CHJIT::CompiledModuleInfo compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name);
using JITCompiledAggregateFunctionV2 = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr *);
CHJIT::CompiledModuleInfo compileAggregateFunctonsV2(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name);
}
#endif

View File

@ -301,6 +301,8 @@ QueryPlanPtr MergeTreeDataSelectExecutor::read(
context->getTemporaryVolume(),
settings.max_threads,
settings.min_free_disk_space_for_temporary_data,
settings.compile_expressions,
settings.min_count_to_compile_aggregate_expression,
header_before_aggregation); // The source header is also an intermediate header
transform_params = std::make_shared<AggregatingTransformParams>(std::move(params), query_info.projection->aggregate_final);
@ -329,7 +331,9 @@ QueryPlanPtr MergeTreeDataSelectExecutor::read(
settings.empty_result_for_aggregation_by_empty_set,
context->getTemporaryVolume(),
settings.max_threads,
settings.min_free_disk_space_for_temporary_data);
settings.min_free_disk_space_for_temporary_data,
settings.compile_aggregate_expressions,
settings.min_count_to_compile_aggregate_expression);
transform_params = std::make_shared<AggregatingTransformParams>(std::move(params), query_info.projection->aggregate_final);
}