mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 01:25:21 +00:00
Aggregate functions update compile interface
This commit is contained in:
parent
3fe559b31f
commit
9b71b1040a
@ -1,9 +1,11 @@
|
||||
#include <functional>
|
||||
|
||||
/** Adapt functor to static method where functor passed as context.
|
||||
* Main use case to convert lambda into function that can be passed into JIT code.
|
||||
*/
|
||||
template <typename Functor>
|
||||
class FunctorToStaticMethodAdaptor : public FunctorToStaticMethodAdaptor<decltype(&Functor::operator())>
|
||||
{
|
||||
public:
|
||||
};
|
||||
|
||||
template <typename R, typename C, typename ...Args>
|
||||
|
@ -165,37 +165,6 @@ public:
|
||||
++this->data(place).denominator;
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
virtual bool isCompilable() const override
|
||||
{
|
||||
using AverageFieldType = AvgFieldType<T>;
|
||||
return std::is_same_v<AverageFieldType, UInt64> || std::is_same_v<AverageFieldType, Int64>;
|
||||
}
|
||||
|
||||
virtual void compile(llvm::IRBuilderBase & builder, llvm::Value * aggregate_function_place, const DataTypePtr & value_type, llvm::Value * value) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
llvm::Type * numerator_type = b.getInt64Ty();
|
||||
llvm::Type * denominator_type = b.getInt64Ty();
|
||||
|
||||
auto * numerator_value_ptr = b.CreatePointerCast(aggregate_function_place, numerator_type->getPointerTo());
|
||||
auto * numerator_value = b.CreateLoad(numerator_type, numerator_value_ptr);
|
||||
auto * value_cast_to_result = nativeCast(b, value_type, value, numerator_type);
|
||||
auto * sum_result_value = numerator_value->getType()->isIntegerTy() ? b.CreateAdd(numerator_value, value_cast_to_result) : b.CreateFAdd(numerator_value, value_cast_to_result);
|
||||
b.CreateStore(sum_result_value, numerator_value_ptr);
|
||||
|
||||
auto * denominator_place_ptr_untyped = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_function_place, 8);
|
||||
auto * denominator_place_ptr = b.CreatePointerCast(denominator_place_ptr_untyped, denominator_type->getPointerTo());
|
||||
auto * denominator_value = b.CreateLoad(denominator_place_ptr, numerator_value_ptr);
|
||||
auto * increate_denominator_value = b.CreateAdd(denominator_value, llvm::ConstantInt::get(denominator_type, 1));
|
||||
b.CreateStore(increate_denominator_value, denominator_place_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
String getName() const final { return "avg"; }
|
||||
};
|
||||
}
|
||||
|
@ -395,20 +395,67 @@ public:
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
virtual bool isCompilable() const override { return Type == AggregateFunctionTypeSum; }
|
||||
bool isCompilable() const override
|
||||
{
|
||||
if constexpr (Type == AggregateFunctionTypeSumKahan)
|
||||
return false;
|
||||
|
||||
virtual void compile(llvm::IRBuilderBase & builder, llvm::Value * aggregate_function_place, const DataTypePtr & value_type, llvm::Value * value) const override
|
||||
auto return_type = getReturnType();
|
||||
|
||||
return canBeNativeType(*return_type);
|
||||
}
|
||||
|
||||
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_native_type = toNativeType(b, removeNullable(getReturnType()));
|
||||
auto * sum_value_ptr = b.CreatePointerCast(aggregate_function_place, return_native_type->getPointerTo());
|
||||
auto * sum_value = b.CreateLoad(return_native_type, sum_value_ptr);
|
||||
auto * value_cast_to_result = nativeCast(b, value_type, value, return_native_type);
|
||||
auto * return_type = toNativeType(b, removeNullable(getReturnType()));
|
||||
auto * aggregate_sum_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
|
||||
|
||||
b.CreateStore(llvm::ConstantInt::get(return_type, 0), aggregate_sum_ptr);
|
||||
}
|
||||
|
||||
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypePtr & value_type, llvm::Value * value) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, removeNullable(getReturnType()));
|
||||
|
||||
auto * sum_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
|
||||
auto * sum_value = b.CreateLoad(return_type, sum_value_ptr);
|
||||
|
||||
auto * value_cast_to_result = nativeCast(b, value_type, value, return_type);
|
||||
auto * sum_result_value = sum_value->getType()->isIntegerTy() ? b.CreateAdd(sum_value, value_cast_to_result) : b.CreateFAdd(sum_value, value_cast_to_result);
|
||||
|
||||
b.CreateStore(sum_result_value, sum_value_ptr);
|
||||
}
|
||||
|
||||
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, removeNullable(getReturnType()));
|
||||
|
||||
auto * sum_value_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, return_type->getPointerTo());
|
||||
auto * sum_value_dst = b.CreateLoad(return_type, sum_value_dst_ptr);
|
||||
|
||||
auto * sum_value_src_ptr = b.CreatePointerCast(aggregate_data_src_ptr, return_type->getPointerTo());
|
||||
auto * sum_value_src = b.CreateLoad(return_type, sum_value_src_ptr);
|
||||
|
||||
auto * sum_return_value = b.CreateAdd(sum_value_dst, sum_value_src);
|
||||
b.CreateStore(sum_return_value, sum_value_dst_ptr);
|
||||
}
|
||||
|
||||
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
auto * return_type = toNativeType(b, removeNullable(getReturnType()));
|
||||
auto * sum_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
|
||||
|
||||
return b.CreateLoad(return_type, sum_value_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
private:
|
||||
|
@ -255,7 +255,22 @@ public:
|
||||
|
||||
virtual bool isCompilable() const { return false; }
|
||||
|
||||
virtual void compile(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_function_place*/, const DataTypePtr & /*value_type*/, llvm::Value * /*value*/) const
|
||||
virtual void compileCreate(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/) const
|
||||
{
|
||||
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
virtual void compileAdd(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/, const DataTypePtr & /*value_type*/, llvm::Value * /*value*/) const
|
||||
{
|
||||
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
virtual void compileMerge(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_dst_ptr*/, llvm::Value * /*aggregate_data_src_ptr*/) const
|
||||
{
|
||||
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
virtual llvm::Value * compileGetResult(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/) const
|
||||
{
|
||||
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
@ -214,6 +214,38 @@ void Aggregator::Params::explain(JSONBuilder::JSONMap & map) const
|
||||
}
|
||||
}
|
||||
|
||||
static CHJIT & getJITInstance()
|
||||
{
|
||||
static CHJIT jit;
|
||||
return jit;
|
||||
}
|
||||
|
||||
static std::string dumpAggregateFunction(const IAggregateFunction * function)
|
||||
{
|
||||
std::string function_dump;
|
||||
|
||||
auto return_type_name = function->getReturnType()->getName();
|
||||
|
||||
function_dump += return_type_name;
|
||||
function_dump += ' ';
|
||||
function_dump += function->getName();
|
||||
function_dump += '(';
|
||||
|
||||
const auto & argument_types = function->getArgumentTypes();
|
||||
for (const auto & argument_type : argument_types)
|
||||
{
|
||||
function_dump += argument_type->getName();
|
||||
function_dump += ',';
|
||||
}
|
||||
|
||||
if (!argument_types.empty())
|
||||
function_dump.pop_back();
|
||||
|
||||
function_dump += ')';
|
||||
|
||||
return function_dump;
|
||||
}
|
||||
|
||||
Aggregator::Aggregator(const Params & params_)
|
||||
: params(params_)
|
||||
{
|
||||
@ -265,8 +297,70 @@ Aggregator::Aggregator(const Params & params_)
|
||||
HashMethodContext::Settings cache_settings;
|
||||
cache_settings.max_threads = params.max_threads;
|
||||
aggregation_state_cache = AggregatedDataVariants::createCache(method_chosen, cache_settings);
|
||||
compileAggregateFunctions();
|
||||
}
|
||||
|
||||
void Aggregator::compileAggregateFunctions()
|
||||
{
|
||||
if (!params.compile_aggregate_expressions ||
|
||||
params.overflow_row)
|
||||
return;
|
||||
|
||||
std::vector<AggregateFunctionWithOffset> functions_to_compile;
|
||||
size_t aggregate_instructions_size = 0;
|
||||
std::string functions_dump;
|
||||
|
||||
/// Add values to the aggregate functions.
|
||||
for (size_t i = 0; i < aggregate_functions.size(); ++i)
|
||||
{
|
||||
const auto * function = aggregate_functions[i];
|
||||
size_t offset_of_aggregate_function = offsets_of_aggregate_states[i];
|
||||
|
||||
if (function && function->isCompilable())
|
||||
{
|
||||
AggregateFunctionWithOffset function_to_compile
|
||||
{
|
||||
.function = function,
|
||||
.aggregate_data_offset = offset_of_aggregate_function
|
||||
};
|
||||
|
||||
std::string function_dump = dumpAggregateFunction(function);
|
||||
functions_dump += function_dump;
|
||||
functions_dump += ' ';
|
||||
|
||||
functions_to_compile.emplace_back(std::move(function_to_compile));
|
||||
}
|
||||
|
||||
++aggregate_instructions_size;
|
||||
}
|
||||
|
||||
if (functions_to_compile.size() != aggregate_instructions_size)
|
||||
return;
|
||||
|
||||
CompiledAggregateFunctions compiled_aggregate_functions;
|
||||
|
||||
{
|
||||
static std::unordered_map<std::string, CompiledAggregateFunctions> aggregation_functions_dump_to_add_compiled;
|
||||
static std::mutex mtx;
|
||||
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
|
||||
auto it = aggregation_functions_dump_to_add_compiled.find(functions_dump);
|
||||
if (it != aggregation_functions_dump_to_add_compiled.end())
|
||||
{
|
||||
compiled_aggregate_functions = it->second;
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG_TRACE(log, "Compile expression {}", functions_dump);
|
||||
|
||||
compiled_aggregate_functions = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_dump);
|
||||
aggregation_functions_dump_to_add_compiled[functions_dump] = compiled_aggregate_functions;
|
||||
}
|
||||
}
|
||||
|
||||
compiled_functions.emplace(std::move(compiled_aggregate_functions));
|
||||
}
|
||||
|
||||
AggregatedDataVariants::Type Aggregator::chooseAggregationMethod()
|
||||
{
|
||||
@ -480,10 +574,124 @@ void NO_INLINE Aggregator::executeImpl(
|
||||
executeImplBatch<true>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
|
||||
}
|
||||
|
||||
static CHJIT & getJITInstance()
|
||||
template <bool no_more_keys, typename Method>
|
||||
void NO_INLINE Aggregator::handleAggregationJIT(
|
||||
Method & method,
|
||||
typename Method::State & state,
|
||||
Arena * aggregates_pool,
|
||||
size_t rows,
|
||||
AggregateFunctionInstruction * aggregate_instructions) const
|
||||
{
|
||||
static CHJIT jit;
|
||||
return jit;
|
||||
std::vector<ColumnData> columns_data;
|
||||
columns_data.reserve(aggregate_functions.size());
|
||||
|
||||
/// Add values to the aggregate functions.
|
||||
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
|
||||
columns_data.emplace_back(getColumnData(inst->batch_arguments[0]));
|
||||
|
||||
auto add_into_aggregate_states_function = compiled_functions->add_into_aggregate_states_function;
|
||||
auto create_aggregate_states_function = compiled_functions->create_aggregate_states_function;
|
||||
|
||||
auto get_aggregate_data = [&](size_t row) -> AggregateDataPtr
|
||||
{
|
||||
AggregateDataPtr aggregate_data;
|
||||
|
||||
if constexpr (!no_more_keys)
|
||||
{
|
||||
auto emplace_result = state.emplaceKey(method.data, row, *aggregates_pool);
|
||||
|
||||
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
|
||||
if (emplace_result.isInserted())
|
||||
{
|
||||
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
|
||||
emplace_result.setMapped(nullptr);
|
||||
|
||||
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
|
||||
create_aggregate_states_function(aggregate_data);
|
||||
|
||||
emplace_result.setMapped(aggregate_data);
|
||||
}
|
||||
else
|
||||
aggregate_data = emplace_result.getMapped();
|
||||
|
||||
assert(aggregate_data != nullptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
/// Add only if the key already exists.
|
||||
/// Overflow row is disabled for JIT.
|
||||
auto find_result = state.findKey(method.data, row, *aggregates_pool);
|
||||
assert(find_result.getMapped() != nullptr);
|
||||
|
||||
aggregate_data = find_result.getMapped();
|
||||
}
|
||||
|
||||
return aggregate_data;
|
||||
};
|
||||
|
||||
GetAggregateDataFunction get_aggregate_data_function = FunctorToStaticMethodAdaptor<decltype(get_aggregate_data)>::unsafeCall;
|
||||
GetAggregateDataContext get_aggregate_data_context = reinterpret_cast<char *>(&get_aggregate_data);
|
||||
|
||||
add_into_aggregate_states_function(rows, columns_data.data(), get_aggregate_data_function, get_aggregate_data_context);
|
||||
}
|
||||
|
||||
template <bool no_more_keys, typename Method>
|
||||
void NO_INLINE Aggregator::handleAggregationDefault(
|
||||
Method & method,
|
||||
typename Method::State & state,
|
||||
Arena * aggregates_pool,
|
||||
size_t rows,
|
||||
AggregateFunctionInstruction * aggregate_instructions,
|
||||
AggregateDataPtr overflow_row) const
|
||||
{
|
||||
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[rows]);
|
||||
|
||||
/// For all rows.
|
||||
for (size_t i = 0; i < rows; ++i)
|
||||
{
|
||||
AggregateDataPtr aggregate_data;
|
||||
|
||||
if constexpr (!no_more_keys)
|
||||
{
|
||||
auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool);
|
||||
|
||||
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
|
||||
if (emplace_result.isInserted())
|
||||
{
|
||||
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
|
||||
emplace_result.setMapped(nullptr);
|
||||
|
||||
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
|
||||
createAggregateStates(aggregate_data);
|
||||
|
||||
emplace_result.setMapped(aggregate_data);
|
||||
}
|
||||
else
|
||||
aggregate_data = emplace_result.getMapped();
|
||||
|
||||
assert(aggregate_data != nullptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
/// Add only if the key already exists.
|
||||
auto find_result = state.findKey(method.data, i, *aggregates_pool);
|
||||
if (find_result.isFound())
|
||||
aggregate_data = find_result.getMapped();
|
||||
else
|
||||
aggregate_data = overflow_row;
|
||||
}
|
||||
|
||||
places[i] = aggregate_data;
|
||||
}
|
||||
|
||||
/// Add values to the aggregate functions.
|
||||
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
|
||||
{
|
||||
if (inst->offsets)
|
||||
inst->batch_that->addBatchArray(rows, places.get(), inst->state_offset, inst->batch_arguments, inst->offsets, aggregates_pool);
|
||||
else
|
||||
inst->batch_that->addBatch(rows, places.get(), inst->state_offset, inst->batch_arguments, aggregates_pool);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool no_more_keys, typename Method>
|
||||
@ -508,7 +716,7 @@ void NO_INLINE Aggregator::executeImplBatch(
|
||||
return;
|
||||
}
|
||||
|
||||
/// Optimization for special case when aggregating by 8bit key.
|
||||
/// Optimization for special case when aggregating by 8bit key.`
|
||||
if constexpr (!no_more_keys && std::is_same_v<Method, typename decltype(AggregatedDataVariants::key8)::element_type>)
|
||||
{
|
||||
/// We use another method if there are aggregate functions with -Array combinator.
|
||||
@ -543,178 +751,10 @@ void NO_INLINE Aggregator::executeImplBatch(
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic case.
|
||||
|
||||
auto get_aggregate_data = [&](size_t row) -> AggregateDataPtr
|
||||
{
|
||||
AggregateDataPtr aggregate_data;
|
||||
|
||||
if constexpr (!no_more_keys)
|
||||
{
|
||||
auto emplace_result = state.emplaceKey(method.data, row, *aggregates_pool);
|
||||
|
||||
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
|
||||
if (emplace_result.isInserted())
|
||||
{
|
||||
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
|
||||
emplace_result.setMapped(nullptr);
|
||||
|
||||
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
|
||||
createAggregateStates(aggregate_data);
|
||||
|
||||
emplace_result.setMapped(aggregate_data);
|
||||
}
|
||||
else
|
||||
aggregate_data = emplace_result.getMapped();
|
||||
|
||||
assert(aggregate_data != nullptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
/// Add only if the key already exists.
|
||||
auto find_result = state.findKey(method.data, row, *aggregates_pool);
|
||||
if (find_result.isFound())
|
||||
aggregate_data = find_result.getMapped();
|
||||
else
|
||||
aggregate_data = overflow_row;
|
||||
}
|
||||
|
||||
// std::cerr << "Row " << row << " returned place " << static_cast<void *>(aggregate_data) << std::endl;
|
||||
return aggregate_data;
|
||||
};
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
std::vector<ColumnData> columns_data;
|
||||
std::vector<AggregateFunctionToCompile> functions_to_compile;
|
||||
size_t aggregate_instructions_size = 0;
|
||||
|
||||
/// Add values to the aggregate functions.
|
||||
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
|
||||
{
|
||||
const auto * function = inst->that;
|
||||
if (function && function->isCompilable())
|
||||
{
|
||||
AggregateFunctionToCompile function_to_compile
|
||||
{
|
||||
.function = inst->that,
|
||||
.aggregate_data_offset = inst->state_offset
|
||||
};
|
||||
|
||||
columns_data.emplace_back(getColumnData(inst->batch_arguments[0]));
|
||||
functions_to_compile.emplace_back(std::move(function_to_compile));
|
||||
}
|
||||
|
||||
++aggregate_instructions_size;
|
||||
}
|
||||
|
||||
if (params.compile_aggregate_expressions && functions_to_compile.size() == aggregate_instructions_size)
|
||||
{
|
||||
std::string functions_dump;
|
||||
|
||||
for (const auto & func : functions_to_compile)
|
||||
{
|
||||
const auto * function = func.function;
|
||||
|
||||
std::string function_dump;
|
||||
|
||||
auto return_type_name = function->getReturnType()->getName();
|
||||
|
||||
function_dump += return_type_name;
|
||||
function_dump += ' ';
|
||||
function_dump += function->getName();
|
||||
function_dump += '(';
|
||||
|
||||
const auto & argument_types = function->getArgumentTypes();
|
||||
for (const auto & argument_type : argument_types)
|
||||
{
|
||||
function_dump += argument_type->getName();
|
||||
function_dump += ',';
|
||||
}
|
||||
|
||||
if (!argument_types.empty())
|
||||
function_dump.pop_back();
|
||||
|
||||
function_dump += ')';
|
||||
|
||||
functions_dump += function_dump;
|
||||
functions_dump += ' ';
|
||||
}
|
||||
|
||||
static std::unordered_map<std::string, CHJIT::CompiledModuleInfo> aggregation_functions_dump_to_compiled_module_info;
|
||||
CHJIT::CompiledModuleInfo compiled_module;
|
||||
|
||||
auto it = aggregation_functions_dump_to_compiled_module_info.find(functions_dump);
|
||||
if (it != aggregation_functions_dump_to_compiled_module_info.end())
|
||||
{
|
||||
compiled_module = it->second;
|
||||
LOG_TRACE(log, "Get compiled aggregate functions {} from cache", functions_dump);
|
||||
|
||||
}
|
||||
else
|
||||
{
|
||||
compiled_module = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_dump);
|
||||
aggregation_functions_dump_to_compiled_module_info[functions_dump] = compiled_module;
|
||||
}
|
||||
|
||||
LOG_TRACE(log, "Use compiled expression {}", functions_dump);
|
||||
|
||||
JITCompiledAggregateFunction aggregate_function = reinterpret_cast<JITCompiledAggregateFunction>(getJITInstance().findCompiledFunction(compiled_module, functions_dump));
|
||||
GetAggregateDataFunction get_aggregate_data_function = FunctorToStaticMethodAdaptor<decltype(get_aggregate_data)>::unsafeCall;
|
||||
GetAggregateDataContext get_aggregate_data_context = reinterpret_cast<char *>(&get_aggregate_data);
|
||||
aggregate_function(rows, columns_data.data(), get_aggregate_data_function, get_aggregate_data_context);
|
||||
}
|
||||
if (compiled_functions)
|
||||
handleAggregationJIT<no_more_keys>(method, state, aggregates_pool, rows, aggregate_instructions);
|
||||
else
|
||||
#endif
|
||||
{
|
||||
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[rows]);
|
||||
|
||||
/// For all rows.
|
||||
for (size_t i = 0; i < rows; ++i)
|
||||
{
|
||||
AggregateDataPtr aggregate_data;
|
||||
|
||||
if constexpr (!no_more_keys)
|
||||
{
|
||||
auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool);
|
||||
|
||||
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
|
||||
if (emplace_result.isInserted())
|
||||
{
|
||||
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
|
||||
emplace_result.setMapped(nullptr);
|
||||
|
||||
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
|
||||
createAggregateStates(aggregate_data);
|
||||
|
||||
emplace_result.setMapped(aggregate_data);
|
||||
}
|
||||
else
|
||||
aggregate_data = emplace_result.getMapped();
|
||||
|
||||
assert(aggregate_data != nullptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
/// Add only if the key already exists.
|
||||
auto find_result = state.findKey(method.data, i, *aggregates_pool);
|
||||
if (find_result.isFound())
|
||||
aggregate_data = find_result.getMapped();
|
||||
else
|
||||
aggregate_data = overflow_row;
|
||||
}
|
||||
|
||||
places[i] = aggregate_data;
|
||||
}
|
||||
|
||||
/// Add values to the aggregate functions.
|
||||
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
|
||||
{
|
||||
if (inst->offsets)
|
||||
inst->batch_that->addBatchArray(rows, places.get(), inst->state_offset, inst->batch_arguments, inst->offsets, aggregates_pool);
|
||||
else
|
||||
inst->batch_that->addBatch(rows, places.get(), inst->state_offset, inst->batch_arguments, aggregates_pool);
|
||||
}
|
||||
}
|
||||
handleAggregationDefault<no_more_keys>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
|
||||
}
|
||||
|
||||
|
||||
@ -1251,11 +1291,38 @@ void NO_INLINE Aggregator::convertToBlockImplFinal(
|
||||
auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns, key_sizes);
|
||||
const auto & key_sizes_ref = shuffled_key_sizes ? *shuffled_key_sizes : key_sizes;
|
||||
|
||||
data.forEachValue([&](const auto & key, auto & mapped)
|
||||
if (compiled_functions)
|
||||
{
|
||||
method.insertKeyIntoColumns(key, key_columns, key_sizes_ref);
|
||||
insertAggregatesIntoColumns(mapped, final_aggregate_columns, arena);
|
||||
});
|
||||
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[data.size()]);
|
||||
size_t place_index = 0;
|
||||
|
||||
data.forEachValue([&](const auto & key, auto & mapped)
|
||||
{
|
||||
method.insertKeyIntoColumns(key, key_columns, key_sizes_ref);
|
||||
places[place_index] = mapped;
|
||||
++place_index;
|
||||
});
|
||||
|
||||
std::vector<ColumnData> columns_data;
|
||||
columns_data.reserve(final_aggregate_columns.size());
|
||||
|
||||
for (auto & final_aggregate_column : final_aggregate_columns)
|
||||
{
|
||||
final_aggregate_column = final_aggregate_column->cloneResized(data.size());
|
||||
columns_data.emplace_back(getColumnData(final_aggregate_column.get()));
|
||||
}
|
||||
|
||||
auto insert_aggregate_states_function = compiled_functions->insert_aggregates_into_columns_function;
|
||||
insert_aggregate_states_function(data.size(), columns_data.data(), places.get());
|
||||
}
|
||||
else
|
||||
{
|
||||
data.forEachValue([&](const auto & key, auto & mapped)
|
||||
{
|
||||
method.insertKeyIntoColumns(key, key_columns, key_sizes_ref);
|
||||
insertAggregatesIntoColumns(mapped, final_aggregate_columns, arena);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Method, typename Table>
|
||||
@ -1684,27 +1751,45 @@ void NO_INLINE Aggregator::mergeDataImpl(
|
||||
if constexpr (Method::low_cardinality_optimization)
|
||||
mergeDataNullKey<Method, Table>(table_dst, table_src, arena);
|
||||
|
||||
table_src.mergeToViaEmplace(table_dst,
|
||||
[&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
|
||||
if (compiled_functions)
|
||||
{
|
||||
if (!inserted)
|
||||
{
|
||||
for (size_t i = 0; i < params.aggregates_size; ++i)
|
||||
aggregate_functions[i]->merge(
|
||||
dst + offsets_of_aggregate_states[i],
|
||||
src + offsets_of_aggregate_states[i],
|
||||
arena);
|
||||
auto merge_aggregate_states_function_typed = compiled_functions->merge_aggregate_states_function;
|
||||
|
||||
for (size_t i = 0; i < params.aggregates_size; ++i)
|
||||
aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]);
|
||||
}
|
||||
else
|
||||
table_src.mergeToViaEmplace(table_dst, [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
|
||||
{
|
||||
dst = src;
|
||||
}
|
||||
if (!inserted)
|
||||
{
|
||||
merge_aggregate_states_function_typed(dst, src);
|
||||
}
|
||||
else
|
||||
{
|
||||
dst = src;
|
||||
}
|
||||
|
||||
src = nullptr;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
table_src.mergeToViaEmplace(table_dst, [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
|
||||
{
|
||||
if (!inserted)
|
||||
{
|
||||
for (size_t i = 0; i < params.aggregates_size; ++i)
|
||||
aggregate_functions[i]->merge(dst + offsets_of_aggregate_states[i], src + offsets_of_aggregate_states[i], arena);
|
||||
|
||||
for (size_t i = 0; i < params.aggregates_size; ++i)
|
||||
aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dst = src;
|
||||
}
|
||||
|
||||
src = nullptr;
|
||||
});
|
||||
}
|
||||
|
||||
src = nullptr;
|
||||
});
|
||||
table_src.clearAndShrink();
|
||||
}
|
||||
|
||||
|
@ -26,6 +26,7 @@
|
||||
|
||||
#include <Interpreters/AggregateDescription.h>
|
||||
#include <Interpreters/AggregationCommon.h>
|
||||
#include <Interpreters/JIT/compileFunction.h>
|
||||
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Columns/ColumnFixedString.h>
|
||||
@ -1082,6 +1083,12 @@ private:
|
||||
/// For external aggregation.
|
||||
TemporaryFiles temporary_files;
|
||||
|
||||
std::optional<CompiledAggregateFunctions> compiled_functions;
|
||||
|
||||
/** Try to compile aggregate functions.
|
||||
*/
|
||||
void compileAggregateFunctions();
|
||||
|
||||
/** Select the aggregation method based on the number and types of keys. */
|
||||
AggregatedDataVariants::Type chooseAggregationMethod();
|
||||
|
||||
@ -1116,6 +1123,41 @@ private:
|
||||
AggregateFunctionInstruction * aggregate_instructions,
|
||||
AggregateDataPtr overflow_row) const;
|
||||
|
||||
template <bool no_more_keys, typename Method>
|
||||
void handleAggregationJIT(
|
||||
Method & method,
|
||||
typename Method::State & state,
|
||||
Arena * aggregates_pool,
|
||||
size_t rows,
|
||||
AggregateFunctionInstruction * aggregate_instructions) const;
|
||||
|
||||
// template <bool no_more_keys, typename Method>
|
||||
// void handleAggregationJITV2(
|
||||
// Method & method,
|
||||
// typename Method::State & state,
|
||||
// Arena * aggregates_pool,
|
||||
// size_t rows,
|
||||
// AggregateFunctionInstruction * aggregate_instructions,
|
||||
// AggregateDataPtr overflow_row) const;
|
||||
|
||||
// template <bool no_more_keys, typename Method>
|
||||
// void handleAggregationJITV3(
|
||||
// Method & method,
|
||||
// typename Method::State & state,
|
||||
// Arena * aggregates_pool,
|
||||
// size_t rows,
|
||||
// AggregateFunctionInstruction * aggregate_instructions,
|
||||
// AggregateDataPtr overflow_row) const;
|
||||
|
||||
template <bool no_more_keys, typename Method>
|
||||
void handleAggregationDefault(
|
||||
Method & method,
|
||||
typename Method::State & state,
|
||||
Arena * aggregates_pool,
|
||||
size_t rows,
|
||||
AggregateFunctionInstruction * aggregate_instructions,
|
||||
AggregateDataPtr overflow_row) const;
|
||||
|
||||
/// For case when there are no keys (all aggregate into one row).
|
||||
static void executeWithoutKeyImpl(
|
||||
AggregatedDataWithoutKey & res,
|
||||
|
@ -42,36 +42,29 @@ static Poco::Logger * getLogger()
|
||||
return &logger;
|
||||
}
|
||||
|
||||
class CompiledFunction
|
||||
class CompiledFunctionHolder
|
||||
{
|
||||
public:
|
||||
|
||||
CompiledFunction(void * compiled_function_, CHJIT::CompiledModuleInfo module_info_)
|
||||
explicit CompiledFunctionHolder(CompiledFunction compiled_function_)
|
||||
: compiled_function(compiled_function_)
|
||||
, module_info(std::move(module_info_))
|
||||
{}
|
||||
|
||||
void * getCompiledFunction() const { return compiled_function; }
|
||||
|
||||
~CompiledFunction()
|
||||
~CompiledFunctionHolder()
|
||||
{
|
||||
getJITInstance().deleteCompiledModule(module_info);
|
||||
getJITInstance().deleteCompiledModule(compiled_function.compiled_module);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void * compiled_function;
|
||||
|
||||
CHJIT::CompiledModuleInfo module_info;
|
||||
CompiledFunction compiled_function;
|
||||
};
|
||||
|
||||
class LLVMExecutableFunction : public IExecutableFunction
|
||||
{
|
||||
public:
|
||||
|
||||
explicit LLVMExecutableFunction(const std::string & name_, std::shared_ptr<CompiledFunction> compiled_function_)
|
||||
explicit LLVMExecutableFunction(const std::string & name_, std::shared_ptr<CompiledFunctionHolder> compiled_function_holder_)
|
||||
: name(name_)
|
||||
, compiled_function(compiled_function_)
|
||||
, compiled_function_holder(compiled_function_holder_)
|
||||
{
|
||||
}
|
||||
|
||||
@ -104,8 +97,8 @@ public:
|
||||
|
||||
columns[arguments.size()] = getColumnData(result_column.get());
|
||||
|
||||
JITCompiledFunction jit_compiled_function_typed = reinterpret_cast<JITCompiledFunction>(compiled_function->getCompiledFunction());
|
||||
jit_compiled_function_typed(input_rows_count, columns.data());
|
||||
auto jit_compiled_function = compiled_function_holder->compiled_function.compiled_function;
|
||||
jit_compiled_function(input_rows_count, columns.data());
|
||||
|
||||
#if defined(MEMORY_SANITIZER)
|
||||
/// Memory sanitizer don't know about stores from JIT-ed code.
|
||||
@ -135,7 +128,7 @@ public:
|
||||
|
||||
private:
|
||||
std::string name;
|
||||
std::shared_ptr<CompiledFunction> compiled_function;
|
||||
std::shared_ptr<CompiledFunctionHolder> compiled_function_holder;
|
||||
};
|
||||
|
||||
class LLVMFunction : public IFunctionBase
|
||||
@ -157,9 +150,9 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void setCompiledFunction(std::shared_ptr<CompiledFunction> compiled_function_)
|
||||
void setCompiledFunction(std::shared_ptr<CompiledFunctionHolder> compiled_function_holder_)
|
||||
{
|
||||
compiled_function = compiled_function_;
|
||||
compiled_function_holder = compiled_function_holder_;
|
||||
}
|
||||
|
||||
bool isCompilable() const override { return true; }
|
||||
@ -177,10 +170,10 @@ public:
|
||||
|
||||
ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override
|
||||
{
|
||||
if (!compiled_function)
|
||||
if (!compiled_function_holder)
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Compiled function was not initialized {}", name);
|
||||
|
||||
return std::make_unique<LLVMExecutableFunction>(name, compiled_function);
|
||||
return std::make_unique<LLVMExecutableFunction>(name, compiled_function_holder);
|
||||
}
|
||||
|
||||
bool isDeterministic() const override
|
||||
@ -269,7 +262,7 @@ private:
|
||||
CompileDAG dag;
|
||||
DataTypes argument_types;
|
||||
std::vector<FunctionBasePtr> nested_functions;
|
||||
std::shared_ptr<CompiledFunction> compiled_function;
|
||||
std::shared_ptr<CompiledFunctionHolder> compiled_function_holder;
|
||||
};
|
||||
|
||||
static FunctionBasePtr compile(
|
||||
@ -293,22 +286,20 @@ static FunctionBasePtr compile(
|
||||
auto [compiled_function_cache_entry, _] = compilation_cache->getOrSet(hash_key, [&] ()
|
||||
{
|
||||
LOG_TRACE(getLogger(), "Compile expression {}", llvm_function->getName());
|
||||
CHJIT::CompiledModuleInfo compiled_module_info = compileFunction(getJITInstance(), *llvm_function);
|
||||
auto * compiled_jit_function = getJITInstance().findCompiledFunction(compiled_module_info, llvm_function->getName());
|
||||
auto compiled_function = std::make_shared<CompiledFunction>(compiled_jit_function, compiled_module_info);
|
||||
auto compiled_function = compileFunction(getJITInstance(), *llvm_function);
|
||||
auto compiled_function_holder = std::make_shared<CompiledFunctionHolder>(compiled_function);
|
||||
|
||||
return std::make_shared<CompiledFunctionCacheEntry>(std::move(compiled_function), compiled_module_info.size);
|
||||
return std::make_shared<CompiledFunctionCacheEntry>(std::move(compiled_function_holder), compiled_function.compiled_module.size);
|
||||
});
|
||||
|
||||
llvm_function->setCompiledFunction(compiled_function_cache_entry->getCompiledFunction());
|
||||
llvm_function->setCompiledFunction(compiled_function_cache_entry->getCompiledFunctionHolder());
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG_TRACE(getLogger(), "Compile expression {}", llvm_function->getName());
|
||||
CHJIT::CompiledModuleInfo compiled_module_info = compileFunction(getJITInstance(), *llvm_function);
|
||||
auto * compiled_jit_function = getJITInstance().findCompiledFunction(compiled_module_info, llvm_function->getName());
|
||||
auto compiled_function = std::make_shared<CompiledFunction>(compiled_jit_function, compiled_module_info);
|
||||
llvm_function->setCompiledFunction(compiled_function);
|
||||
auto compiled_function = compileFunction(getJITInstance(), *llvm_function);
|
||||
auto compiled_function_ptr = std::make_shared<CompiledFunctionHolder>(compiled_function);
|
||||
|
||||
llvm_function->setCompiledFunction(compiled_function_ptr);
|
||||
}
|
||||
|
||||
return llvm_function;
|
||||
|
@ -11,22 +11,22 @@
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class CompiledFunction;
|
||||
class CompiledFunctionHolder;
|
||||
|
||||
class CompiledFunctionCacheEntry
|
||||
{
|
||||
public:
|
||||
CompiledFunctionCacheEntry(std::shared_ptr<CompiledFunction> compiled_function_, size_t compiled_function_size_)
|
||||
: compiled_function(std::move(compiled_function_))
|
||||
CompiledFunctionCacheEntry(std::shared_ptr<CompiledFunctionHolder> compiled_function_holder_, size_t compiled_function_size_)
|
||||
: compiled_function_holder(std::move(compiled_function_holder_))
|
||||
, compiled_function_size(compiled_function_size_)
|
||||
{}
|
||||
|
||||
std::shared_ptr<CompiledFunction> getCompiledFunction() const { return compiled_function; }
|
||||
std::shared_ptr<CompiledFunctionHolder> getCompiledFunctionHolder() const { return compiled_function_holder; }
|
||||
|
||||
size_t getCompiledFunctionSize() const { return compiled_function_size; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<CompiledFunction> compiled_function;
|
||||
std::shared_ptr<CompiledFunctionHolder> compiled_function_holder;
|
||||
|
||||
size_t compiled_function_size;
|
||||
};
|
||||
|
@ -189,7 +189,7 @@ CHJIT::CHJIT()
|
||||
|
||||
CHJIT::~CHJIT() = default;
|
||||
|
||||
CHJIT::CompiledModuleInfo CHJIT::compileModule(std::function<void (llvm::Module &)> compile_function)
|
||||
CHJIT::CompiledModule CHJIT::compileModule(std::function<void (llvm::Module &)> compile_function)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(jit_lock);
|
||||
|
||||
@ -210,12 +210,15 @@ std::unique_ptr<llvm::Module> CHJIT::createModuleForCompilation()
|
||||
return module;
|
||||
}
|
||||
|
||||
CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> module)
|
||||
CHJIT::CompiledModule CHJIT::compileModule(std::unique_ptr<llvm::Module> module)
|
||||
{
|
||||
runOptimizationPassesOnModule(*module);
|
||||
|
||||
auto buffer = compiler->compile(*module);
|
||||
|
||||
// llvm::errs() << "Module after optimizations " << "\n";
|
||||
// module->print(llvm::errs(), nullptr);
|
||||
|
||||
llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> object = llvm::object::ObjectFile::createObjectFile(*buffer);
|
||||
|
||||
if (!object)
|
||||
@ -234,7 +237,7 @@ CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> mod
|
||||
dynamic_linker.resolveRelocations();
|
||||
module_memory_manager->getManager().finalizeMemory();
|
||||
|
||||
CompiledModuleInfo module_info;
|
||||
CompiledModule compiled_module;
|
||||
|
||||
for (const auto & function : *module)
|
||||
{
|
||||
@ -250,47 +253,29 @@ CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> mod
|
||||
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "DynamicLinker could not found symbol {} after compilation", function_name);
|
||||
|
||||
auto * jit_symbol_address = reinterpret_cast<void *>(jit_symbol.getAddress());
|
||||
|
||||
std::string symbol_name = std::to_string(current_module_key) + '_' + function_name;
|
||||
name_to_symbol[symbol_name] = jit_symbol_address;
|
||||
module_info.compiled_functions.emplace_back(std::move(function_name));
|
||||
compiled_module.function_name_to_symbol.emplace(std::move(function_name), jit_symbol_address);
|
||||
}
|
||||
|
||||
module_info.size = module_memory_manager->getAllocatedSize();
|
||||
module_info.identifier = current_module_key;
|
||||
compiled_module.size = module_memory_manager->getAllocatedSize();
|
||||
compiled_module.identifier = current_module_key;
|
||||
|
||||
module_identifier_to_memory_manager[current_module_key] = std::move(module_memory_manager);
|
||||
|
||||
compiled_code_size.fetch_add(module_info.size, std::memory_order_relaxed);
|
||||
compiled_code_size.fetch_add(compiled_module.size, std::memory_order_relaxed);
|
||||
|
||||
return module_info;
|
||||
return compiled_module;
|
||||
}
|
||||
|
||||
void CHJIT::deleteCompiledModule(const CHJIT::CompiledModuleInfo & module_info)
|
||||
void CHJIT::deleteCompiledModule(const CHJIT::CompiledModule & module)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(jit_lock);
|
||||
|
||||
auto module_it = module_identifier_to_memory_manager.find(module_info.identifier);
|
||||
auto module_it = module_identifier_to_memory_manager.find(module.identifier);
|
||||
if (module_it == module_identifier_to_memory_manager.end())
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "There is no compiled module with identifier {}", module_info.identifier);
|
||||
|
||||
for (const auto & function : module_info.compiled_functions)
|
||||
name_to_symbol.erase(function);
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "There is no compiled module with identifier {}", module.identifier);
|
||||
|
||||
module_identifier_to_memory_manager.erase(module_it);
|
||||
compiled_code_size.fetch_sub(module_info.size, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void * CHJIT::findCompiledFunction(const CompiledModuleInfo & module_info, const std::string & function_name) const
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(jit_lock);
|
||||
|
||||
std::string symbol_name = std::to_string(module_info.identifier) + '_' + function_name;
|
||||
auto it = name_to_symbol.find(symbol_name);
|
||||
if (it != name_to_symbol.end())
|
||||
return it->second;
|
||||
|
||||
return nullptr;
|
||||
compiled_code_size.fetch_sub(module.size, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void CHJIT::registerExternalSymbol(const std::string & symbol_name, void * address)
|
||||
|
@ -52,32 +52,31 @@ public:
|
||||
|
||||
~CHJIT();
|
||||
|
||||
struct CompiledModuleInfo
|
||||
struct CompiledModule
|
||||
{
|
||||
/// Size of compiled module code in bytes
|
||||
size_t size;
|
||||
|
||||
/// Module identifier. Should not be changed by client
|
||||
uint64_t identifier;
|
||||
/// Vector of compiled function nameds. Should not be changed by client
|
||||
std::vector<std::string> compiled_functions;
|
||||
|
||||
/// Vector of compiled functions. Should not be changed by client.
|
||||
/// It is client responsibility to cast result function to right signature.
|
||||
/// After call to deleteCompiledModule compiled functions from module become invalid.
|
||||
std::unordered_map<std::string, void *> function_name_to_symbol;
|
||||
|
||||
};
|
||||
|
||||
/** Compile module. In compile function client responsibility is to fill module with necessary
|
||||
* IR code, then it will be compiled by CHJIT instance.
|
||||
* Return compiled module info.
|
||||
* Return compiled module.
|
||||
*/
|
||||
CompiledModuleInfo compileModule(std::function<void (llvm::Module &)> compile_function);
|
||||
CompiledModule compileModule(std::function<void (llvm::Module &)> compile_function);
|
||||
|
||||
/** Delete compiled module. Pointers to functions from module become invalid after this call.
|
||||
* It is client responsibility to be sure that there are no pointers to compiled module code.
|
||||
*/
|
||||
void deleteCompiledModule(const CompiledModuleInfo & module_info);
|
||||
|
||||
/** Find compiled function using module_info, and function_name.
|
||||
* It is client responsibility to case result function to right signature.
|
||||
* After call to deleteCompiledModule compiled functions from module become invalid.
|
||||
*/
|
||||
void * findCompiledFunction(const CompiledModuleInfo & module_info, const std::string & function_name) const;
|
||||
void deleteCompiledModule(const CompiledModule & module_info);
|
||||
|
||||
/** Register external symbol for CHJIT instance to use, during linking.
|
||||
* It can be function, or global constant.
|
||||
@ -93,7 +92,7 @@ private:
|
||||
|
||||
std::unique_ptr<llvm::Module> createModuleForCompilation();
|
||||
|
||||
CompiledModuleInfo compileModule(std::unique_ptr<llvm::Module> module);
|
||||
CompiledModule compileModule(std::unique_ptr<llvm::Module> module);
|
||||
|
||||
std::string getMangledName(const std::string & name_to_mangle) const;
|
||||
|
||||
@ -107,7 +106,6 @@ private:
|
||||
std::unique_ptr<JITCompiler> compiler;
|
||||
std::unique_ptr<JITSymbolResolver> symbol_resolver;
|
||||
|
||||
std::unordered_map<std::string, void *> name_to_symbol;
|
||||
std::unordered_map<uint64_t, std::unique_ptr<JITModuleMemoryManager>> module_identifier_to_memory_manager;
|
||||
uint64_t current_module_key = 0;
|
||||
std::atomic<size_t> compiled_code_size = 0;
|
||||
|
@ -250,205 +250,288 @@ static void compileFunction(llvm::Module & module, const IFunctionBase & functio
|
||||
b.CreateRetVoid();
|
||||
}
|
||||
|
||||
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBase & function)
|
||||
CompiledFunction compileFunction(CHJIT & jit, const IFunctionBase & function)
|
||||
{
|
||||
Stopwatch watch;
|
||||
|
||||
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
|
||||
auto compiled_module = jit.compileModule([&](llvm::Module & module)
|
||||
{
|
||||
compileFunction(module, function);
|
||||
});
|
||||
|
||||
ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
|
||||
ProfileEvents::increment(ProfileEvents::CompileExpressionsBytes, compiled_module_info.size);
|
||||
ProfileEvents::increment(ProfileEvents::CompileExpressionsBytes, compiled_module.size);
|
||||
ProfileEvents::increment(ProfileEvents::CompileFunction);
|
||||
|
||||
return compiled_module_info;
|
||||
auto compiled_function_ptr = reinterpret_cast<JITCompiledFunction>(compiled_module.function_name_to_symbol[function.getName()]);
|
||||
assert(compiled_function_ptr);
|
||||
|
||||
CompiledFunction result_compiled_function
|
||||
{
|
||||
.compiled_function = compiled_function_ptr,
|
||||
.compiled_module = compiled_module
|
||||
};
|
||||
|
||||
return result_compiled_function;
|
||||
}
|
||||
|
||||
CHJIT::CompiledModuleInfo compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name)
|
||||
static void compileCreateAggregateStatesFunctions(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
|
||||
{
|
||||
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
|
||||
auto & context = module.getContext();
|
||||
llvm::IRBuilder<> b(context);
|
||||
|
||||
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo();
|
||||
auto * create_aggregate_states_function_type = llvm::FunctionType::get(b.getVoidTy(), { aggregate_data_places_type }, false);
|
||||
auto * create_aggregate_states_function = llvm::Function::Create(create_aggregate_states_function_type, llvm::Function::ExternalLinkage, name, module);
|
||||
|
||||
auto * arguments = create_aggregate_states_function->args().begin();
|
||||
llvm::Value * aggregate_data_place_arg = arguments++;
|
||||
|
||||
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", create_aggregate_states_function);
|
||||
b.SetInsertPoint(entry);
|
||||
|
||||
std::vector<ColumnDataPlaceholder> columns(functions.size());
|
||||
for (const auto & function_to_compile : functions)
|
||||
{
|
||||
auto & context = module.getContext();
|
||||
llvm::IRBuilder<> b (context);
|
||||
size_t aggregate_function_offset = function_to_compile.aggregate_data_offset;
|
||||
const auto * aggregate_function = function_to_compile.function;
|
||||
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_arg, aggregate_function_offset);
|
||||
aggregate_function->compileCreate(b, aggregation_place_with_offset);
|
||||
}
|
||||
|
||||
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
|
||||
module.print(llvm::errs(), nullptr);
|
||||
|
||||
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
|
||||
auto * get_place_func_declaration = llvm::FunctionType::get(b.getInt8Ty()->getPointerTo(), { b.getInt8Ty()->getPointerTo(), size_type }, /*isVarArg=*/false);
|
||||
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), get_place_func_declaration->getPointerTo(), b.getInt8Ty()->getPointerTo() }, false);
|
||||
|
||||
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, result_name, module);
|
||||
|
||||
auto * arguments = aggregate_loop_func_definition->args().begin();
|
||||
llvm::Value * rows_count_arg = &*arguments++;
|
||||
llvm::Value * columns_arg = &*arguments++;
|
||||
llvm::Value * get_place_function_arg = &*arguments++;
|
||||
llvm::Value * get_place_function_context_arg = &*arguments++;
|
||||
|
||||
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
|
||||
|
||||
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
|
||||
b.SetInsertPoint(entry);
|
||||
|
||||
std::vector<ColumnDataPlaceholder> columns(functions.size());
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
auto argument_type = functions[i].function->getArgumentTypes()[0];
|
||||
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
|
||||
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
|
||||
}
|
||||
|
||||
/// Initialize loop
|
||||
|
||||
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
|
||||
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
|
||||
|
||||
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
|
||||
|
||||
b.SetInsertPoint(loop);
|
||||
|
||||
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
|
||||
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
|
||||
|
||||
for (auto & col : columns)
|
||||
{
|
||||
col.data = b.CreatePHI(col.data_init->getType(), 2);
|
||||
col.data->addIncoming(col.data_init, entry);
|
||||
}
|
||||
|
||||
auto * aggregation_place = b.CreateCall(get_place_func_declaration, get_place_function_arg, { get_place_function_context_arg, counter_phi });
|
||||
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
|
||||
const auto * aggregate_function_ptr = functions[i].function;
|
||||
|
||||
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregation_place, aggregate_function_offset);
|
||||
|
||||
auto column_type = functions[i].function->getArgumentTypes()[0];
|
||||
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
|
||||
aggregate_function_ptr->compile(b, aggregation_place_with_offset, column_type, column_data);
|
||||
}
|
||||
|
||||
/// End of loop
|
||||
|
||||
auto * cur_block = b.GetInsertBlock();
|
||||
for (auto & col : columns)
|
||||
{
|
||||
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
|
||||
if (col.null)
|
||||
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
|
||||
}
|
||||
|
||||
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
|
||||
counter_phi->addIncoming(value, loop);
|
||||
|
||||
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
|
||||
|
||||
b.SetInsertPoint(end);
|
||||
b.CreateRetVoid();
|
||||
|
||||
llvm::errs() << "Module before optimizations \n";
|
||||
module.print(llvm::errs(), nullptr);
|
||||
});
|
||||
|
||||
return compiled_module_info;
|
||||
b.CreateRetVoid();
|
||||
}
|
||||
|
||||
CHJIT::CompiledModuleInfo compileAggregateFunctonsV2(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name)
|
||||
static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
|
||||
{
|
||||
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
|
||||
auto & context = module.getContext();
|
||||
llvm::IRBuilder<> b(context);
|
||||
|
||||
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
|
||||
|
||||
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
|
||||
auto * get_place_func_declaration = llvm::FunctionType::get(b.getInt8Ty()->getPointerTo(), { b.getInt8Ty()->getPointerTo(), size_type }, /*isVarArg=*/false);
|
||||
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), get_place_func_declaration->getPointerTo(), b.getInt8Ty()->getPointerTo() }, false);
|
||||
|
||||
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
|
||||
|
||||
auto * arguments = aggregate_loop_func_definition->args().begin();
|
||||
llvm::Value * rows_count_arg = arguments++;
|
||||
llvm::Value * columns_arg = arguments++;
|
||||
llvm::Value * get_place_function_arg = arguments++;
|
||||
llvm::Value * get_place_function_context_arg = arguments++;
|
||||
|
||||
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
|
||||
|
||||
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
|
||||
b.SetInsertPoint(entry);
|
||||
|
||||
std::vector<ColumnDataPlaceholder> columns(functions.size());
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
auto & context = module.getContext();
|
||||
llvm::IRBuilder<> b (context);
|
||||
auto argument_type = functions[i].function->getArgumentTypes()[0];
|
||||
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
|
||||
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
|
||||
}
|
||||
|
||||
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
|
||||
/// Initialize loop
|
||||
|
||||
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
|
||||
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo()->getPointerTo();
|
||||
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), aggregate_data_places_type }, false);
|
||||
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
|
||||
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
|
||||
|
||||
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, result_name, module);
|
||||
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
|
||||
|
||||
auto * arguments = aggregate_loop_func_definition->args().begin();
|
||||
llvm::Value * rows_count_arg = &*arguments++;
|
||||
llvm::Value * columns_arg = &*arguments++;
|
||||
llvm::Value * aggregate_data_places_arg = &*arguments++;
|
||||
b.SetInsertPoint(loop);
|
||||
|
||||
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
|
||||
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
|
||||
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
|
||||
|
||||
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
|
||||
b.SetInsertPoint(entry);
|
||||
for (auto & col : columns)
|
||||
{
|
||||
col.data = b.CreatePHI(col.data_init->getType(), 2);
|
||||
col.data->addIncoming(col.data_init, entry);
|
||||
}
|
||||
|
||||
std::vector<ColumnDataPlaceholder> columns(functions.size());
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
auto argument_type = functions[i].function->getArgumentTypes()[0];
|
||||
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
|
||||
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
|
||||
}
|
||||
auto * aggregation_place = b.CreateCall(get_place_func_declaration, get_place_function_arg, { get_place_function_context_arg, counter_phi });
|
||||
|
||||
/// Initialize loop
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
|
||||
const auto * aggregate_function_ptr = functions[i].function;
|
||||
|
||||
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
|
||||
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
|
||||
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregation_place, aggregate_function_offset);
|
||||
|
||||
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
|
||||
auto column_type = functions[i].function->getArgumentTypes()[0];
|
||||
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
|
||||
aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, column_type, column_data);
|
||||
}
|
||||
|
||||
b.SetInsertPoint(loop);
|
||||
/// End of loop
|
||||
|
||||
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
|
||||
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
|
||||
auto * cur_block = b.GetInsertBlock();
|
||||
for (auto & col : columns)
|
||||
{
|
||||
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
|
||||
if (col.null)
|
||||
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
|
||||
}
|
||||
|
||||
auto * aggregate_data_place_phi = b.CreatePHI(aggregate_data_places_type, 2);
|
||||
aggregate_data_place_phi->addIncoming(aggregate_data_places_arg, entry);
|
||||
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
|
||||
counter_phi->addIncoming(value, loop);
|
||||
|
||||
for (auto & col : columns)
|
||||
{
|
||||
col.data = b.CreatePHI(col.data_init->getType(), 2);
|
||||
col.data->addIncoming(col.data_init, entry);
|
||||
}
|
||||
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
|
||||
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
|
||||
const auto * aggregate_function_ptr = functions[i].function;
|
||||
b.SetInsertPoint(end);
|
||||
b.CreateRetVoid();
|
||||
}
|
||||
|
||||
auto * aggregate_data_place = b.CreateLoad(b.getInt8Ty()->getPointerTo(), aggregate_data_place_phi);
|
||||
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place, aggregate_function_offset);
|
||||
static void compileMergeAggregatesStates(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
|
||||
{
|
||||
auto & context = module.getContext();
|
||||
llvm::IRBuilder<> b(context);
|
||||
|
||||
auto column_type = functions[i].function->getArgumentTypes()[0];
|
||||
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
|
||||
aggregate_function_ptr->compile(b, aggregation_place_with_offset, column_type, column_data);
|
||||
}
|
||||
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo();
|
||||
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { aggregate_data_places_type, aggregate_data_places_type }, false);
|
||||
auto * aggregate_loop_func = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
|
||||
|
||||
/// End of loop
|
||||
auto * arguments = aggregate_loop_func->args().begin();
|
||||
llvm::Value * aggregate_data_place_dst_arg = arguments++;
|
||||
llvm::Value * aggregate_data_place_src_arg = arguments++;
|
||||
|
||||
auto * cur_block = b.GetInsertBlock();
|
||||
for (auto & col : columns)
|
||||
{
|
||||
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
|
||||
if (col.null)
|
||||
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
|
||||
}
|
||||
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func);
|
||||
b.SetInsertPoint(entry);
|
||||
|
||||
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1), "", true, true);
|
||||
counter_phi->addIncoming(value, loop);
|
||||
for (const auto & function_to_compile : functions)
|
||||
{
|
||||
size_t aggregate_function_offset = function_to_compile.aggregate_data_offset;
|
||||
const auto * aggregate_function_ptr = function_to_compile.function;
|
||||
|
||||
aggregate_data_place_phi->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_phi, 1), loop);
|
||||
auto * aggregate_data_place_merge_dst_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_dst_arg, aggregate_function_offset);
|
||||
auto * aggregate_data_place_merge_src_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_src_arg, aggregate_function_offset);
|
||||
|
||||
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
|
||||
aggregate_function_ptr->compileMerge(b, aggregate_data_place_merge_dst_with_offset, aggregate_data_place_merge_src_with_offset);
|
||||
}
|
||||
|
||||
b.SetInsertPoint(end);
|
||||
b.CreateRetVoid();
|
||||
b.CreateRetVoid();
|
||||
}
|
||||
|
||||
llvm::errs() << "Module before optimizations \n";
|
||||
module.print(llvm::errs(), nullptr);
|
||||
static void compileInsertAggregatesIntoResultColumns(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
|
||||
{
|
||||
auto & context = module.getContext();
|
||||
llvm::IRBuilder<> b(context);
|
||||
|
||||
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
|
||||
|
||||
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
|
||||
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo()->getPointerTo();
|
||||
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), aggregate_data_places_type }, false);
|
||||
auto * aggregate_loop_func = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
|
||||
|
||||
auto * arguments = aggregate_loop_func->args().begin();
|
||||
llvm::Value * rows_count_arg = &*arguments++;
|
||||
llvm::Value * columns_arg = &*arguments++;
|
||||
llvm::Value * aggregate_data_places_arg = &*arguments++;
|
||||
|
||||
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func);
|
||||
b.SetInsertPoint(entry);
|
||||
|
||||
std::vector<ColumnDataPlaceholder> columns(functions.size());
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
auto return_type = functions[i].function->getReturnType();
|
||||
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
|
||||
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(return_type))->getPointerTo());
|
||||
}
|
||||
|
||||
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func);
|
||||
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func);
|
||||
|
||||
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
|
||||
|
||||
b.SetInsertPoint(loop);
|
||||
|
||||
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
|
||||
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
|
||||
|
||||
auto * aggregate_data_place_phi = b.CreatePHI(aggregate_data_places_type, 2);
|
||||
aggregate_data_place_phi->addIncoming(aggregate_data_places_arg, entry);
|
||||
|
||||
for (auto & col : columns)
|
||||
{
|
||||
col.data = b.CreatePHI(col.data_init->getType(), 2);
|
||||
col.data->addIncoming(col.data_init, entry);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
|
||||
const auto * aggregate_function_ptr = functions[i].function;
|
||||
|
||||
auto * aggregate_data_place = b.CreateLoad(b.getInt8Ty()->getPointerTo(), aggregate_data_place_phi);
|
||||
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place, aggregate_function_offset);
|
||||
|
||||
auto column_type = functions[i].function->getArgumentTypes()[0];
|
||||
auto * final_value = aggregate_function_ptr->compileGetResult(b, aggregation_place_with_offset);
|
||||
b.CreateStore(final_value, columns[i].data);
|
||||
}
|
||||
|
||||
/// End of loop
|
||||
|
||||
auto * cur_block = b.GetInsertBlock();
|
||||
for (auto & col : columns)
|
||||
{
|
||||
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
|
||||
if (col.null)
|
||||
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
|
||||
}
|
||||
|
||||
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1), "", true, true);
|
||||
counter_phi->addIncoming(value, loop);
|
||||
|
||||
aggregate_data_place_phi->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_phi, 1), loop);
|
||||
|
||||
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
|
||||
|
||||
b.SetInsertPoint(end);
|
||||
b.CreateRetVoid();
|
||||
}
|
||||
|
||||
CompiledAggregateFunctions compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionWithOffset> & functions, std::string functions_dump_name)
|
||||
{
|
||||
std::string create_aggregate_states_functions_name = functions_dump_name + "_create";
|
||||
std::string add_aggregate_states_functions_name = functions_dump_name + "_add";
|
||||
std::string merge_aggregate_states_functions_name = functions_dump_name + "_merge";
|
||||
std::string insert_aggregate_states_functions_name = functions_dump_name + "_insert";
|
||||
|
||||
auto compiled_module = jit.compileModule([&](llvm::Module & module)
|
||||
{
|
||||
compileCreateAggregateStatesFunctions(module, functions, create_aggregate_states_functions_name);
|
||||
compileAddIntoAggregateStatesFunctions(module, functions, add_aggregate_states_functions_name);
|
||||
compileMergeAggregatesStates(module, functions, merge_aggregate_states_functions_name);
|
||||
compileInsertAggregatesIntoResultColumns(module, functions, insert_aggregate_states_functions_name);
|
||||
});
|
||||
|
||||
return compiled_module_info;
|
||||
auto create_aggregate_states_function = reinterpret_cast<JITCreateAggregateStatesFunction>(compiled_module.function_name_to_symbol[create_aggregate_states_functions_name]);
|
||||
auto add_into_aggregate_states_function = reinterpret_cast<JITAddIntoAggregateStatesFunction>(compiled_module.function_name_to_symbol[add_aggregate_states_functions_name]);
|
||||
auto merge_aggregate_states_function = reinterpret_cast<JITMergeAggregateStatesFunction>(compiled_module.function_name_to_symbol[merge_aggregate_states_functions_name]);
|
||||
auto insert_aggregate_states_function = reinterpret_cast<JITInsertAggregatesIntoColumnsFunction>(compiled_module.function_name_to_symbol[insert_aggregate_states_functions_name]);
|
||||
|
||||
assert(create_aggregate_states_function);
|
||||
assert(add_into_aggregate_states_function);
|
||||
assert(merge_aggregate_states_function);
|
||||
assert(insert_aggregate_states_function);
|
||||
|
||||
CompiledAggregateFunctions compiled_aggregate_functions
|
||||
{
|
||||
.create_aggregate_states_function = create_aggregate_states_function,
|
||||
.add_into_aggregate_states_function = add_into_aggregate_states_function,
|
||||
.merge_aggregate_states_function = merge_aggregate_states_function,
|
||||
.insert_aggregates_into_columns_function = insert_aggregate_states_function
|
||||
};
|
||||
|
||||
return compiled_aggregate_functions;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -32,6 +32,14 @@ using ColumnDataRowsSize = size_t;
|
||||
|
||||
using JITCompiledFunction = void (*)(ColumnDataRowsSize, ColumnData *);
|
||||
|
||||
struct CompiledFunction
|
||||
{
|
||||
|
||||
JITCompiledFunction compiled_function;
|
||||
|
||||
CHJIT::CompiledModule compiled_module;
|
||||
};
|
||||
|
||||
/** Compile function to native jit code using CHJIT instance.
|
||||
* Function is compiled as single module.
|
||||
* After this function execution, code for function will be compiled and can be queried using
|
||||
@ -41,22 +49,33 @@ using JITCompiledFunction = void (*)(ColumnDataRowsSize, ColumnData *);
|
||||
* It is important that ColumnData parameter of JITCompiledFunction is result column,
|
||||
* and will be filled by compiled function.
|
||||
*/
|
||||
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBase & function);
|
||||
CompiledFunction compileFunction(CHJIT & jit, const IFunctionBase & function);
|
||||
|
||||
using GetAggregateDataContext = char *;
|
||||
using GetAggregateDataFunction = AggregateDataPtr (*)(GetAggregateDataContext, size_t);
|
||||
using JITCompiledAggregateFunction = void (*)(ColumnDataRowsSize, ColumnData *, GetAggregateDataFunction, GetAggregateDataContext);
|
||||
|
||||
struct AggregateFunctionToCompile
|
||||
struct AggregateFunctionWithOffset
|
||||
{
|
||||
const IAggregateFunction * function;
|
||||
size_t aggregate_data_offset;
|
||||
};
|
||||
|
||||
CHJIT::CompiledModuleInfo compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name);
|
||||
using GetAggregateDataContext = char *;
|
||||
using GetAggregateDataFunction = AggregateDataPtr (*)(GetAggregateDataContext, size_t);
|
||||
|
||||
using JITCompiledAggregateFunctionV2 = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr *);
|
||||
CHJIT::CompiledModuleInfo compileAggregateFunctonsV2(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name);
|
||||
using JITCreateAggregateStatesFunction = void (*)(AggregateDataPtr);
|
||||
using JITAddIntoAggregateStatesFunction = void (*)(ColumnDataRowsSize, ColumnData *, GetAggregateDataFunction, GetAggregateDataContext);
|
||||
using JITMergeAggregateStatesFunction = void (*)(AggregateDataPtr, AggregateDataPtr);
|
||||
using JITInsertAggregatesIntoColumnsFunction = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr *);
|
||||
|
||||
struct CompiledAggregateFunctions
|
||||
{
|
||||
JITCreateAggregateStatesFunction create_aggregate_states_function;
|
||||
JITAddIntoAggregateStatesFunction add_into_aggregate_states_function;
|
||||
JITMergeAggregateStatesFunction merge_aggregate_states_function;
|
||||
JITInsertAggregatesIntoColumnsFunction insert_aggregates_into_columns_function;
|
||||
|
||||
CHJIT::CompiledModule compiled_module;
|
||||
};
|
||||
|
||||
CompiledAggregateFunctions compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionWithOffset> & functions, std::string functions_dump_name);
|
||||
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,7 @@ int main(int argc, char **argv)
|
||||
|
||||
jit.registerExternalSymbol("test_function", reinterpret_cast<void *>(&test_function));
|
||||
|
||||
auto compiled_module_info = jit.compileModule([](llvm::Module & module)
|
||||
auto compiled_module = jit.compileModule([](llvm::Module & module)
|
||||
{
|
||||
auto & context = module.getContext();
|
||||
llvm::IRBuilder<> b (context);
|
||||
@ -43,13 +43,14 @@ int main(int argc, char **argv)
|
||||
b.CreateRet(value);
|
||||
});
|
||||
|
||||
for (const auto & compiled_function_name : compiled_module_info.compiled_functions)
|
||||
for (const auto & [compiled_function_name, _] : compiled_module.function_name_to_symbol)
|
||||
{
|
||||
std::cerr << compiled_function_name << std::endl;
|
||||
}
|
||||
|
||||
int64_t value = 5;
|
||||
auto * test_name_function = reinterpret_cast<int64_t (*)(int64_t *)>(jit.findCompiledFunction(compiled_module_info, "test_name"));
|
||||
auto * symbol = compiled_module.function_name_to_symbol["test_name"];
|
||||
auto * test_name_function = reinterpret_cast<int64_t (*)(int64_t *)>(symbol);
|
||||
auto result = test_name_function(&value);
|
||||
std::cerr << "Result " << result << std::endl;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user