Aggregate functions update compile interface

This commit is contained in:
Maksim Kita 2021-06-03 22:20:53 +03:00
parent 3fe559b31f
commit 9b71b1040a
13 changed files with 718 additions and 481 deletions

View File

@ -1,9 +1,11 @@
#include <functional>
/** Adapt functor to static method where functor passed as context.
* Main use case to convert lambda into function that can be passed into JIT code.
*/
template <typename Functor>
class FunctorToStaticMethodAdaptor : public FunctorToStaticMethodAdaptor<decltype(&Functor::operator())>
{
public:
};
template <typename R, typename C, typename ...Args>

View File

@ -165,37 +165,6 @@ public:
++this->data(place).denominator;
}
#if USE_EMBEDDED_COMPILER
virtual bool isCompilable() const override
{
using AverageFieldType = AvgFieldType<T>;
return std::is_same_v<AverageFieldType, UInt64> || std::is_same_v<AverageFieldType, Int64>;
}
virtual void compile(llvm::IRBuilderBase & builder, llvm::Value * aggregate_function_place, const DataTypePtr & value_type, llvm::Value * value) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
llvm::Type * numerator_type = b.getInt64Ty();
llvm::Type * denominator_type = b.getInt64Ty();
auto * numerator_value_ptr = b.CreatePointerCast(aggregate_function_place, numerator_type->getPointerTo());
auto * numerator_value = b.CreateLoad(numerator_type, numerator_value_ptr);
auto * value_cast_to_result = nativeCast(b, value_type, value, numerator_type);
auto * sum_result_value = numerator_value->getType()->isIntegerTy() ? b.CreateAdd(numerator_value, value_cast_to_result) : b.CreateFAdd(numerator_value, value_cast_to_result);
b.CreateStore(sum_result_value, numerator_value_ptr);
auto * denominator_place_ptr_untyped = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_function_place, 8);
auto * denominator_place_ptr = b.CreatePointerCast(denominator_place_ptr_untyped, denominator_type->getPointerTo());
auto * denominator_value = b.CreateLoad(denominator_place_ptr, numerator_value_ptr);
auto * increate_denominator_value = b.CreateAdd(denominator_value, llvm::ConstantInt::get(denominator_type, 1));
b.CreateStore(increate_denominator_value, denominator_place_ptr);
}
#endif
String getName() const final { return "avg"; }
};
}

View File

@ -395,20 +395,67 @@ public:
#if USE_EMBEDDED_COMPILER
virtual bool isCompilable() const override { return Type == AggregateFunctionTypeSum; }
bool isCompilable() const override
{
if constexpr (Type == AggregateFunctionTypeSumKahan)
return false;
virtual void compile(llvm::IRBuilderBase & builder, llvm::Value * aggregate_function_place, const DataTypePtr & value_type, llvm::Value * value) const override
auto return_type = getReturnType();
return canBeNativeType(*return_type);
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_native_type = toNativeType(b, removeNullable(getReturnType()));
auto * sum_value_ptr = b.CreatePointerCast(aggregate_function_place, return_native_type->getPointerTo());
auto * sum_value = b.CreateLoad(return_native_type, sum_value_ptr);
auto * value_cast_to_result = nativeCast(b, value_type, value, return_native_type);
auto * return_type = toNativeType(b, removeNullable(getReturnType()));
auto * aggregate_sum_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
b.CreateStore(llvm::ConstantInt::get(return_type, 0), aggregate_sum_ptr);
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypePtr & value_type, llvm::Value * value) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, removeNullable(getReturnType()));
auto * sum_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
auto * sum_value = b.CreateLoad(return_type, sum_value_ptr);
auto * value_cast_to_result = nativeCast(b, value_type, value, return_type);
auto * sum_result_value = sum_value->getType()->isIntegerTy() ? b.CreateAdd(sum_value, value_cast_to_result) : b.CreateFAdd(sum_value, value_cast_to_result);
b.CreateStore(sum_result_value, sum_value_ptr);
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, removeNullable(getReturnType()));
auto * sum_value_dst_ptr = b.CreatePointerCast(aggregate_data_dst_ptr, return_type->getPointerTo());
auto * sum_value_dst = b.CreateLoad(return_type, sum_value_dst_ptr);
auto * sum_value_src_ptr = b.CreatePointerCast(aggregate_data_src_ptr, return_type->getPointerTo());
auto * sum_value_src = b.CreateLoad(return_type, sum_value_src_ptr);
auto * sum_return_value = b.CreateAdd(sum_value_dst, sum_value_src);
b.CreateStore(sum_return_value, sum_value_dst_ptr);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * return_type = toNativeType(b, removeNullable(getReturnType()));
auto * sum_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
return b.CreateLoad(return_type, sum_value_ptr);
}
#endif
private:

View File

@ -255,7 +255,22 @@ public:
virtual bool isCompilable() const { return false; }
virtual void compile(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_function_place*/, const DataTypePtr & /*value_type*/, llvm::Value * /*value*/) const
virtual void compileCreate(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/) const
{
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}
virtual void compileAdd(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/, const DataTypePtr & /*value_type*/, llvm::Value * /*value*/) const
{
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}
virtual void compileMerge(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_dst_ptr*/, llvm::Value * /*aggregate_data_src_ptr*/) const
{
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}
virtual llvm::Value * compileGetResult(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/) const
{
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}

View File

@ -214,6 +214,38 @@ void Aggregator::Params::explain(JSONBuilder::JSONMap & map) const
}
}
static CHJIT & getJITInstance()
{
static CHJIT jit;
return jit;
}
static std::string dumpAggregateFunction(const IAggregateFunction * function)
{
std::string function_dump;
auto return_type_name = function->getReturnType()->getName();
function_dump += return_type_name;
function_dump += ' ';
function_dump += function->getName();
function_dump += '(';
const auto & argument_types = function->getArgumentTypes();
for (const auto & argument_type : argument_types)
{
function_dump += argument_type->getName();
function_dump += ',';
}
if (!argument_types.empty())
function_dump.pop_back();
function_dump += ')';
return function_dump;
}
Aggregator::Aggregator(const Params & params_)
: params(params_)
{
@ -265,8 +297,70 @@ Aggregator::Aggregator(const Params & params_)
HashMethodContext::Settings cache_settings;
cache_settings.max_threads = params.max_threads;
aggregation_state_cache = AggregatedDataVariants::createCache(method_chosen, cache_settings);
compileAggregateFunctions();
}
void Aggregator::compileAggregateFunctions()
{
if (!params.compile_aggregate_expressions ||
params.overflow_row)
return;
std::vector<AggregateFunctionWithOffset> functions_to_compile;
size_t aggregate_instructions_size = 0;
std::string functions_dump;
/// Add values to the aggregate functions.
for (size_t i = 0; i < aggregate_functions.size(); ++i)
{
const auto * function = aggregate_functions[i];
size_t offset_of_aggregate_function = offsets_of_aggregate_states[i];
if (function && function->isCompilable())
{
AggregateFunctionWithOffset function_to_compile
{
.function = function,
.aggregate_data_offset = offset_of_aggregate_function
};
std::string function_dump = dumpAggregateFunction(function);
functions_dump += function_dump;
functions_dump += ' ';
functions_to_compile.emplace_back(std::move(function_to_compile));
}
++aggregate_instructions_size;
}
if (functions_to_compile.size() != aggregate_instructions_size)
return;
CompiledAggregateFunctions compiled_aggregate_functions;
{
static std::unordered_map<std::string, CompiledAggregateFunctions> aggregation_functions_dump_to_add_compiled;
static std::mutex mtx;
std::lock_guard<std::mutex> lock(mtx);
auto it = aggregation_functions_dump_to_add_compiled.find(functions_dump);
if (it != aggregation_functions_dump_to_add_compiled.end())
{
compiled_aggregate_functions = it->second;
}
else
{
LOG_TRACE(log, "Compile expression {}", functions_dump);
compiled_aggregate_functions = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_dump);
aggregation_functions_dump_to_add_compiled[functions_dump] = compiled_aggregate_functions;
}
}
compiled_functions.emplace(std::move(compiled_aggregate_functions));
}
AggregatedDataVariants::Type Aggregator::chooseAggregationMethod()
{
@ -480,10 +574,124 @@ void NO_INLINE Aggregator::executeImpl(
executeImplBatch<true>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
}
static CHJIT & getJITInstance()
template <bool no_more_keys, typename Method>
void NO_INLINE Aggregator::handleAggregationJIT(
Method & method,
typename Method::State & state,
Arena * aggregates_pool,
size_t rows,
AggregateFunctionInstruction * aggregate_instructions) const
{
static CHJIT jit;
return jit;
std::vector<ColumnData> columns_data;
columns_data.reserve(aggregate_functions.size());
/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
columns_data.emplace_back(getColumnData(inst->batch_arguments[0]));
auto add_into_aggregate_states_function = compiled_functions->add_into_aggregate_states_function;
auto create_aggregate_states_function = compiled_functions->create_aggregate_states_function;
auto get_aggregate_data = [&](size_t row) -> AggregateDataPtr
{
AggregateDataPtr aggregate_data;
if constexpr (!no_more_keys)
{
auto emplace_result = state.emplaceKey(method.data, row, *aggregates_pool);
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
if (emplace_result.isInserted())
{
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
emplace_result.setMapped(nullptr);
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
create_aggregate_states_function(aggregate_data);
emplace_result.setMapped(aggregate_data);
}
else
aggregate_data = emplace_result.getMapped();
assert(aggregate_data != nullptr);
}
else
{
/// Add only if the key already exists.
/// Overflow row is disabled for JIT.
auto find_result = state.findKey(method.data, row, *aggregates_pool);
assert(find_result.getMapped() != nullptr);
aggregate_data = find_result.getMapped();
}
return aggregate_data;
};
GetAggregateDataFunction get_aggregate_data_function = FunctorToStaticMethodAdaptor<decltype(get_aggregate_data)>::unsafeCall;
GetAggregateDataContext get_aggregate_data_context = reinterpret_cast<char *>(&get_aggregate_data);
add_into_aggregate_states_function(rows, columns_data.data(), get_aggregate_data_function, get_aggregate_data_context);
}
template <bool no_more_keys, typename Method>
void NO_INLINE Aggregator::handleAggregationDefault(
Method & method,
typename Method::State & state,
Arena * aggregates_pool,
size_t rows,
AggregateFunctionInstruction * aggregate_instructions,
AggregateDataPtr overflow_row) const
{
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[rows]);
/// For all rows.
for (size_t i = 0; i < rows; ++i)
{
AggregateDataPtr aggregate_data;
if constexpr (!no_more_keys)
{
auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool);
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
if (emplace_result.isInserted())
{
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
emplace_result.setMapped(nullptr);
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
createAggregateStates(aggregate_data);
emplace_result.setMapped(aggregate_data);
}
else
aggregate_data = emplace_result.getMapped();
assert(aggregate_data != nullptr);
}
else
{
/// Add only if the key already exists.
auto find_result = state.findKey(method.data, i, *aggregates_pool);
if (find_result.isFound())
aggregate_data = find_result.getMapped();
else
aggregate_data = overflow_row;
}
places[i] = aggregate_data;
}
/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
{
if (inst->offsets)
inst->batch_that->addBatchArray(rows, places.get(), inst->state_offset, inst->batch_arguments, inst->offsets, aggregates_pool);
else
inst->batch_that->addBatch(rows, places.get(), inst->state_offset, inst->batch_arguments, aggregates_pool);
}
}
template <bool no_more_keys, typename Method>
@ -508,7 +716,7 @@ void NO_INLINE Aggregator::executeImplBatch(
return;
}
/// Optimization for special case when aggregating by 8bit key.
/// Optimization for special case when aggregating by 8bit key.`
if constexpr (!no_more_keys && std::is_same_v<Method, typename decltype(AggregatedDataVariants::key8)::element_type>)
{
/// We use another method if there are aggregate functions with -Array combinator.
@ -543,178 +751,10 @@ void NO_INLINE Aggregator::executeImplBatch(
}
}
/// Generic case.
auto get_aggregate_data = [&](size_t row) -> AggregateDataPtr
{
AggregateDataPtr aggregate_data;
if constexpr (!no_more_keys)
{
auto emplace_result = state.emplaceKey(method.data, row, *aggregates_pool);
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
if (emplace_result.isInserted())
{
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
emplace_result.setMapped(nullptr);
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
createAggregateStates(aggregate_data);
emplace_result.setMapped(aggregate_data);
}
else
aggregate_data = emplace_result.getMapped();
assert(aggregate_data != nullptr);
}
else
{
/// Add only if the key already exists.
auto find_result = state.findKey(method.data, row, *aggregates_pool);
if (find_result.isFound())
aggregate_data = find_result.getMapped();
else
aggregate_data = overflow_row;
}
// std::cerr << "Row " << row << " returned place " << static_cast<void *>(aggregate_data) << std::endl;
return aggregate_data;
};
#if USE_EMBEDDED_COMPILER
std::vector<ColumnData> columns_data;
std::vector<AggregateFunctionToCompile> functions_to_compile;
size_t aggregate_instructions_size = 0;
/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
{
const auto * function = inst->that;
if (function && function->isCompilable())
{
AggregateFunctionToCompile function_to_compile
{
.function = inst->that,
.aggregate_data_offset = inst->state_offset
};
columns_data.emplace_back(getColumnData(inst->batch_arguments[0]));
functions_to_compile.emplace_back(std::move(function_to_compile));
}
++aggregate_instructions_size;
}
if (params.compile_aggregate_expressions && functions_to_compile.size() == aggregate_instructions_size)
{
std::string functions_dump;
for (const auto & func : functions_to_compile)
{
const auto * function = func.function;
std::string function_dump;
auto return_type_name = function->getReturnType()->getName();
function_dump += return_type_name;
function_dump += ' ';
function_dump += function->getName();
function_dump += '(';
const auto & argument_types = function->getArgumentTypes();
for (const auto & argument_type : argument_types)
{
function_dump += argument_type->getName();
function_dump += ',';
}
if (!argument_types.empty())
function_dump.pop_back();
function_dump += ')';
functions_dump += function_dump;
functions_dump += ' ';
}
static std::unordered_map<std::string, CHJIT::CompiledModuleInfo> aggregation_functions_dump_to_compiled_module_info;
CHJIT::CompiledModuleInfo compiled_module;
auto it = aggregation_functions_dump_to_compiled_module_info.find(functions_dump);
if (it != aggregation_functions_dump_to_compiled_module_info.end())
{
compiled_module = it->second;
LOG_TRACE(log, "Get compiled aggregate functions {} from cache", functions_dump);
}
else
{
compiled_module = compileAggregateFunctons(getJITInstance(), functions_to_compile, functions_dump);
aggregation_functions_dump_to_compiled_module_info[functions_dump] = compiled_module;
}
LOG_TRACE(log, "Use compiled expression {}", functions_dump);
JITCompiledAggregateFunction aggregate_function = reinterpret_cast<JITCompiledAggregateFunction>(getJITInstance().findCompiledFunction(compiled_module, functions_dump));
GetAggregateDataFunction get_aggregate_data_function = FunctorToStaticMethodAdaptor<decltype(get_aggregate_data)>::unsafeCall;
GetAggregateDataContext get_aggregate_data_context = reinterpret_cast<char *>(&get_aggregate_data);
aggregate_function(rows, columns_data.data(), get_aggregate_data_function, get_aggregate_data_context);
}
if (compiled_functions)
handleAggregationJIT<no_more_keys>(method, state, aggregates_pool, rows, aggregate_instructions);
else
#endif
{
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[rows]);
/// For all rows.
for (size_t i = 0; i < rows; ++i)
{
AggregateDataPtr aggregate_data;
if constexpr (!no_more_keys)
{
auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool);
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
if (emplace_result.isInserted())
{
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
emplace_result.setMapped(nullptr);
aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
createAggregateStates(aggregate_data);
emplace_result.setMapped(aggregate_data);
}
else
aggregate_data = emplace_result.getMapped();
assert(aggregate_data != nullptr);
}
else
{
/// Add only if the key already exists.
auto find_result = state.findKey(method.data, i, *aggregates_pool);
if (find_result.isFound())
aggregate_data = find_result.getMapped();
else
aggregate_data = overflow_row;
}
places[i] = aggregate_data;
}
/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
{
if (inst->offsets)
inst->batch_that->addBatchArray(rows, places.get(), inst->state_offset, inst->batch_arguments, inst->offsets, aggregates_pool);
else
inst->batch_that->addBatch(rows, places.get(), inst->state_offset, inst->batch_arguments, aggregates_pool);
}
}
handleAggregationDefault<no_more_keys>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
}
@ -1251,11 +1291,38 @@ void NO_INLINE Aggregator::convertToBlockImplFinal(
auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns, key_sizes);
const auto & key_sizes_ref = shuffled_key_sizes ? *shuffled_key_sizes : key_sizes;
data.forEachValue([&](const auto & key, auto & mapped)
if (compiled_functions)
{
method.insertKeyIntoColumns(key, key_columns, key_sizes_ref);
insertAggregatesIntoColumns(mapped, final_aggregate_columns, arena);
});
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[data.size()]);
size_t place_index = 0;
data.forEachValue([&](const auto & key, auto & mapped)
{
method.insertKeyIntoColumns(key, key_columns, key_sizes_ref);
places[place_index] = mapped;
++place_index;
});
std::vector<ColumnData> columns_data;
columns_data.reserve(final_aggregate_columns.size());
for (auto & final_aggregate_column : final_aggregate_columns)
{
final_aggregate_column = final_aggregate_column->cloneResized(data.size());
columns_data.emplace_back(getColumnData(final_aggregate_column.get()));
}
auto insert_aggregate_states_function = compiled_functions->insert_aggregates_into_columns_function;
insert_aggregate_states_function(data.size(), columns_data.data(), places.get());
}
else
{
data.forEachValue([&](const auto & key, auto & mapped)
{
method.insertKeyIntoColumns(key, key_columns, key_sizes_ref);
insertAggregatesIntoColumns(mapped, final_aggregate_columns, arena);
});
}
}
template <typename Method, typename Table>
@ -1684,27 +1751,45 @@ void NO_INLINE Aggregator::mergeDataImpl(
if constexpr (Method::low_cardinality_optimization)
mergeDataNullKey<Method, Table>(table_dst, table_src, arena);
table_src.mergeToViaEmplace(table_dst,
[&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
if (compiled_functions)
{
if (!inserted)
{
for (size_t i = 0; i < params.aggregates_size; ++i)
aggregate_functions[i]->merge(
dst + offsets_of_aggregate_states[i],
src + offsets_of_aggregate_states[i],
arena);
auto merge_aggregate_states_function_typed = compiled_functions->merge_aggregate_states_function;
for (size_t i = 0; i < params.aggregates_size; ++i)
aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]);
}
else
table_src.mergeToViaEmplace(table_dst, [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
{
dst = src;
}
if (!inserted)
{
merge_aggregate_states_function_typed(dst, src);
}
else
{
dst = src;
}
src = nullptr;
});
}
else
{
table_src.mergeToViaEmplace(table_dst, [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
{
if (!inserted)
{
for (size_t i = 0; i < params.aggregates_size; ++i)
aggregate_functions[i]->merge(dst + offsets_of_aggregate_states[i], src + offsets_of_aggregate_states[i], arena);
for (size_t i = 0; i < params.aggregates_size; ++i)
aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]);
}
else
{
dst = src;
}
src = nullptr;
});
}
src = nullptr;
});
table_src.clearAndShrink();
}

View File

@ -26,6 +26,7 @@
#include <Interpreters/AggregateDescription.h>
#include <Interpreters/AggregationCommon.h>
#include <Interpreters/JIT/compileFunction.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
@ -1082,6 +1083,12 @@ private:
/// For external aggregation.
TemporaryFiles temporary_files;
std::optional<CompiledAggregateFunctions> compiled_functions;
/** Try to compile aggregate functions.
*/
void compileAggregateFunctions();
/** Select the aggregation method based on the number and types of keys. */
AggregatedDataVariants::Type chooseAggregationMethod();
@ -1116,6 +1123,41 @@ private:
AggregateFunctionInstruction * aggregate_instructions,
AggregateDataPtr overflow_row) const;
template <bool no_more_keys, typename Method>
void handleAggregationJIT(
Method & method,
typename Method::State & state,
Arena * aggregates_pool,
size_t rows,
AggregateFunctionInstruction * aggregate_instructions) const;
// template <bool no_more_keys, typename Method>
// void handleAggregationJITV2(
// Method & method,
// typename Method::State & state,
// Arena * aggregates_pool,
// size_t rows,
// AggregateFunctionInstruction * aggregate_instructions,
// AggregateDataPtr overflow_row) const;
// template <bool no_more_keys, typename Method>
// void handleAggregationJITV3(
// Method & method,
// typename Method::State & state,
// Arena * aggregates_pool,
// size_t rows,
// AggregateFunctionInstruction * aggregate_instructions,
// AggregateDataPtr overflow_row) const;
template <bool no_more_keys, typename Method>
void handleAggregationDefault(
Method & method,
typename Method::State & state,
Arena * aggregates_pool,
size_t rows,
AggregateFunctionInstruction * aggregate_instructions,
AggregateDataPtr overflow_row) const;
/// For case when there are no keys (all aggregate into one row).
static void executeWithoutKeyImpl(
AggregatedDataWithoutKey & res,

View File

@ -42,36 +42,29 @@ static Poco::Logger * getLogger()
return &logger;
}
class CompiledFunction
class CompiledFunctionHolder
{
public:
CompiledFunction(void * compiled_function_, CHJIT::CompiledModuleInfo module_info_)
explicit CompiledFunctionHolder(CompiledFunction compiled_function_)
: compiled_function(compiled_function_)
, module_info(std::move(module_info_))
{}
void * getCompiledFunction() const { return compiled_function; }
~CompiledFunction()
~CompiledFunctionHolder()
{
getJITInstance().deleteCompiledModule(module_info);
getJITInstance().deleteCompiledModule(compiled_function.compiled_module);
}
private:
void * compiled_function;
CHJIT::CompiledModuleInfo module_info;
CompiledFunction compiled_function;
};
class LLVMExecutableFunction : public IExecutableFunction
{
public:
explicit LLVMExecutableFunction(const std::string & name_, std::shared_ptr<CompiledFunction> compiled_function_)
explicit LLVMExecutableFunction(const std::string & name_, std::shared_ptr<CompiledFunctionHolder> compiled_function_holder_)
: name(name_)
, compiled_function(compiled_function_)
, compiled_function_holder(compiled_function_holder_)
{
}
@ -104,8 +97,8 @@ public:
columns[arguments.size()] = getColumnData(result_column.get());
JITCompiledFunction jit_compiled_function_typed = reinterpret_cast<JITCompiledFunction>(compiled_function->getCompiledFunction());
jit_compiled_function_typed(input_rows_count, columns.data());
auto jit_compiled_function = compiled_function_holder->compiled_function.compiled_function;
jit_compiled_function(input_rows_count, columns.data());
#if defined(MEMORY_SANITIZER)
/// Memory sanitizer don't know about stores from JIT-ed code.
@ -135,7 +128,7 @@ public:
private:
std::string name;
std::shared_ptr<CompiledFunction> compiled_function;
std::shared_ptr<CompiledFunctionHolder> compiled_function_holder;
};
class LLVMFunction : public IFunctionBase
@ -157,9 +150,9 @@ public:
}
}
void setCompiledFunction(std::shared_ptr<CompiledFunction> compiled_function_)
void setCompiledFunction(std::shared_ptr<CompiledFunctionHolder> compiled_function_holder_)
{
compiled_function = compiled_function_;
compiled_function_holder = compiled_function_holder_;
}
bool isCompilable() const override { return true; }
@ -177,10 +170,10 @@ public:
ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override
{
if (!compiled_function)
if (!compiled_function_holder)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Compiled function was not initialized {}", name);
return std::make_unique<LLVMExecutableFunction>(name, compiled_function);
return std::make_unique<LLVMExecutableFunction>(name, compiled_function_holder);
}
bool isDeterministic() const override
@ -269,7 +262,7 @@ private:
CompileDAG dag;
DataTypes argument_types;
std::vector<FunctionBasePtr> nested_functions;
std::shared_ptr<CompiledFunction> compiled_function;
std::shared_ptr<CompiledFunctionHolder> compiled_function_holder;
};
static FunctionBasePtr compile(
@ -293,22 +286,20 @@ static FunctionBasePtr compile(
auto [compiled_function_cache_entry, _] = compilation_cache->getOrSet(hash_key, [&] ()
{
LOG_TRACE(getLogger(), "Compile expression {}", llvm_function->getName());
CHJIT::CompiledModuleInfo compiled_module_info = compileFunction(getJITInstance(), *llvm_function);
auto * compiled_jit_function = getJITInstance().findCompiledFunction(compiled_module_info, llvm_function->getName());
auto compiled_function = std::make_shared<CompiledFunction>(compiled_jit_function, compiled_module_info);
auto compiled_function = compileFunction(getJITInstance(), *llvm_function);
auto compiled_function_holder = std::make_shared<CompiledFunctionHolder>(compiled_function);
return std::make_shared<CompiledFunctionCacheEntry>(std::move(compiled_function), compiled_module_info.size);
return std::make_shared<CompiledFunctionCacheEntry>(std::move(compiled_function_holder), compiled_function.compiled_module.size);
});
llvm_function->setCompiledFunction(compiled_function_cache_entry->getCompiledFunction());
llvm_function->setCompiledFunction(compiled_function_cache_entry->getCompiledFunctionHolder());
}
else
{
LOG_TRACE(getLogger(), "Compile expression {}", llvm_function->getName());
CHJIT::CompiledModuleInfo compiled_module_info = compileFunction(getJITInstance(), *llvm_function);
auto * compiled_jit_function = getJITInstance().findCompiledFunction(compiled_module_info, llvm_function->getName());
auto compiled_function = std::make_shared<CompiledFunction>(compiled_jit_function, compiled_module_info);
llvm_function->setCompiledFunction(compiled_function);
auto compiled_function = compileFunction(getJITInstance(), *llvm_function);
auto compiled_function_ptr = std::make_shared<CompiledFunctionHolder>(compiled_function);
llvm_function->setCompiledFunction(compiled_function_ptr);
}
return llvm_function;

View File

@ -11,22 +11,22 @@
namespace DB
{
class CompiledFunction;
class CompiledFunctionHolder;
class CompiledFunctionCacheEntry
{
public:
CompiledFunctionCacheEntry(std::shared_ptr<CompiledFunction> compiled_function_, size_t compiled_function_size_)
: compiled_function(std::move(compiled_function_))
CompiledFunctionCacheEntry(std::shared_ptr<CompiledFunctionHolder> compiled_function_holder_, size_t compiled_function_size_)
: compiled_function_holder(std::move(compiled_function_holder_))
, compiled_function_size(compiled_function_size_)
{}
std::shared_ptr<CompiledFunction> getCompiledFunction() const { return compiled_function; }
std::shared_ptr<CompiledFunctionHolder> getCompiledFunctionHolder() const { return compiled_function_holder; }
size_t getCompiledFunctionSize() const { return compiled_function_size; }
private:
std::shared_ptr<CompiledFunction> compiled_function;
std::shared_ptr<CompiledFunctionHolder> compiled_function_holder;
size_t compiled_function_size;
};

View File

@ -189,7 +189,7 @@ CHJIT::CHJIT()
CHJIT::~CHJIT() = default;
CHJIT::CompiledModuleInfo CHJIT::compileModule(std::function<void (llvm::Module &)> compile_function)
CHJIT::CompiledModule CHJIT::compileModule(std::function<void (llvm::Module &)> compile_function)
{
std::lock_guard<std::mutex> lock(jit_lock);
@ -210,12 +210,15 @@ std::unique_ptr<llvm::Module> CHJIT::createModuleForCompilation()
return module;
}
CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> module)
CHJIT::CompiledModule CHJIT::compileModule(std::unique_ptr<llvm::Module> module)
{
runOptimizationPassesOnModule(*module);
auto buffer = compiler->compile(*module);
// llvm::errs() << "Module after optimizations " << "\n";
// module->print(llvm::errs(), nullptr);
llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> object = llvm::object::ObjectFile::createObjectFile(*buffer);
if (!object)
@ -234,7 +237,7 @@ CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> mod
dynamic_linker.resolveRelocations();
module_memory_manager->getManager().finalizeMemory();
CompiledModuleInfo module_info;
CompiledModule compiled_module;
for (const auto & function : *module)
{
@ -250,47 +253,29 @@ CHJIT::CompiledModuleInfo CHJIT::compileModule(std::unique_ptr<llvm::Module> mod
throw Exception(ErrorCodes::CANNOT_COMPILE_CODE, "DynamicLinker could not found symbol {} after compilation", function_name);
auto * jit_symbol_address = reinterpret_cast<void *>(jit_symbol.getAddress());
std::string symbol_name = std::to_string(current_module_key) + '_' + function_name;
name_to_symbol[symbol_name] = jit_symbol_address;
module_info.compiled_functions.emplace_back(std::move(function_name));
compiled_module.function_name_to_symbol.emplace(std::move(function_name), jit_symbol_address);
}
module_info.size = module_memory_manager->getAllocatedSize();
module_info.identifier = current_module_key;
compiled_module.size = module_memory_manager->getAllocatedSize();
compiled_module.identifier = current_module_key;
module_identifier_to_memory_manager[current_module_key] = std::move(module_memory_manager);
compiled_code_size.fetch_add(module_info.size, std::memory_order_relaxed);
compiled_code_size.fetch_add(compiled_module.size, std::memory_order_relaxed);
return module_info;
return compiled_module;
}
void CHJIT::deleteCompiledModule(const CHJIT::CompiledModuleInfo & module_info)
void CHJIT::deleteCompiledModule(const CHJIT::CompiledModule & module)
{
std::lock_guard<std::mutex> lock(jit_lock);
auto module_it = module_identifier_to_memory_manager.find(module_info.identifier);
auto module_it = module_identifier_to_memory_manager.find(module.identifier);
if (module_it == module_identifier_to_memory_manager.end())
throw Exception(ErrorCodes::LOGICAL_ERROR, "There is no compiled module with identifier {}", module_info.identifier);
for (const auto & function : module_info.compiled_functions)
name_to_symbol.erase(function);
throw Exception(ErrorCodes::LOGICAL_ERROR, "There is no compiled module with identifier {}", module.identifier);
module_identifier_to_memory_manager.erase(module_it);
compiled_code_size.fetch_sub(module_info.size, std::memory_order_relaxed);
}
void * CHJIT::findCompiledFunction(const CompiledModuleInfo & module_info, const std::string & function_name) const
{
std::lock_guard<std::mutex> lock(jit_lock);
std::string symbol_name = std::to_string(module_info.identifier) + '_' + function_name;
auto it = name_to_symbol.find(symbol_name);
if (it != name_to_symbol.end())
return it->second;
return nullptr;
compiled_code_size.fetch_sub(module.size, std::memory_order_relaxed);
}
void CHJIT::registerExternalSymbol(const std::string & symbol_name, void * address)

View File

@ -52,32 +52,31 @@ public:
~CHJIT();
struct CompiledModuleInfo
struct CompiledModule
{
/// Size of compiled module code in bytes
size_t size;
/// Module identifier. Should not be changed by client
uint64_t identifier;
/// Vector of compiled function nameds. Should not be changed by client
std::vector<std::string> compiled_functions;
/// Vector of compiled functions. Should not be changed by client.
/// It is client responsibility to cast result function to right signature.
/// After call to deleteCompiledModule compiled functions from module become invalid.
std::unordered_map<std::string, void *> function_name_to_symbol;
};
/** Compile module. In compile function client responsibility is to fill module with necessary
* IR code, then it will be compiled by CHJIT instance.
* Return compiled module info.
* Return compiled module.
*/
CompiledModuleInfo compileModule(std::function<void (llvm::Module &)> compile_function);
CompiledModule compileModule(std::function<void (llvm::Module &)> compile_function);
/** Delete compiled module. Pointers to functions from module become invalid after this call.
* It is client responsibility to be sure that there are no pointers to compiled module code.
*/
void deleteCompiledModule(const CompiledModuleInfo & module_info);
/** Find compiled function using module_info, and function_name.
* It is client responsibility to case result function to right signature.
* After call to deleteCompiledModule compiled functions from module become invalid.
*/
void * findCompiledFunction(const CompiledModuleInfo & module_info, const std::string & function_name) const;
void deleteCompiledModule(const CompiledModule & module_info);
/** Register external symbol for CHJIT instance to use, during linking.
* It can be function, or global constant.
@ -93,7 +92,7 @@ private:
std::unique_ptr<llvm::Module> createModuleForCompilation();
CompiledModuleInfo compileModule(std::unique_ptr<llvm::Module> module);
CompiledModule compileModule(std::unique_ptr<llvm::Module> module);
std::string getMangledName(const std::string & name_to_mangle) const;
@ -107,7 +106,6 @@ private:
std::unique_ptr<JITCompiler> compiler;
std::unique_ptr<JITSymbolResolver> symbol_resolver;
std::unordered_map<std::string, void *> name_to_symbol;
std::unordered_map<uint64_t, std::unique_ptr<JITModuleMemoryManager>> module_identifier_to_memory_manager;
uint64_t current_module_key = 0;
std::atomic<size_t> compiled_code_size = 0;

View File

@ -250,205 +250,288 @@ static void compileFunction(llvm::Module & module, const IFunctionBase & functio
b.CreateRetVoid();
}
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBase & function)
CompiledFunction compileFunction(CHJIT & jit, const IFunctionBase & function)
{
Stopwatch watch;
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
auto compiled_module = jit.compileModule([&](llvm::Module & module)
{
compileFunction(module, function);
});
ProfileEvents::increment(ProfileEvents::CompileExpressionsMicroseconds, watch.elapsedMicroseconds());
ProfileEvents::increment(ProfileEvents::CompileExpressionsBytes, compiled_module_info.size);
ProfileEvents::increment(ProfileEvents::CompileExpressionsBytes, compiled_module.size);
ProfileEvents::increment(ProfileEvents::CompileFunction);
return compiled_module_info;
auto compiled_function_ptr = reinterpret_cast<JITCompiledFunction>(compiled_module.function_name_to_symbol[function.getName()]);
assert(compiled_function_ptr);
CompiledFunction result_compiled_function
{
.compiled_function = compiled_function_ptr,
.compiled_module = compiled_module
};
return result_compiled_function;
}
CHJIT::CompiledModuleInfo compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name)
static void compileCreateAggregateStatesFunctions(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
auto & context = module.getContext();
llvm::IRBuilder<> b(context);
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo();
auto * create_aggregate_states_function_type = llvm::FunctionType::get(b.getVoidTy(), { aggregate_data_places_type }, false);
auto * create_aggregate_states_function = llvm::Function::Create(create_aggregate_states_function_type, llvm::Function::ExternalLinkage, name, module);
auto * arguments = create_aggregate_states_function->args().begin();
llvm::Value * aggregate_data_place_arg = arguments++;
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", create_aggregate_states_function);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (const auto & function_to_compile : functions)
{
auto & context = module.getContext();
llvm::IRBuilder<> b (context);
size_t aggregate_function_offset = function_to_compile.aggregate_data_offset;
const auto * aggregate_function = function_to_compile.function;
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_arg, aggregate_function_offset);
aggregate_function->compileCreate(b, aggregation_place_with_offset);
}
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
module.print(llvm::errs(), nullptr);
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * get_place_func_declaration = llvm::FunctionType::get(b.getInt8Ty()->getPointerTo(), { b.getInt8Ty()->getPointerTo(), size_type }, /*isVarArg=*/false);
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), get_place_func_declaration->getPointerTo(), b.getInt8Ty()->getPointerTo() }, false);
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, result_name, module);
auto * arguments = aggregate_loop_func_definition->args().begin();
llvm::Value * rows_count_arg = &*arguments++;
llvm::Value * columns_arg = &*arguments++;
llvm::Value * get_place_function_arg = &*arguments++;
llvm::Value * get_place_function_context_arg = &*arguments++;
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (size_t i = 0; i < functions.size(); ++i)
{
auto argument_type = functions[i].function->getArgumentTypes()[0];
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
}
/// Initialize loop
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
}
auto * aggregation_place = b.CreateCall(get_place_func_declaration, get_place_function_arg, { get_place_function_context_arg, counter_phi });
for (size_t i = 0; i < functions.size(); ++i)
{
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
const auto * aggregate_function_ptr = functions[i].function;
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregation_place, aggregate_function_offset);
auto column_type = functions[i].function->getArgumentTypes()[0];
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
aggregate_function_ptr->compile(b, aggregation_place_with_offset, column_type, column_data);
}
/// End of loop
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
counter_phi->addIncoming(value, loop);
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
b.SetInsertPoint(end);
b.CreateRetVoid();
llvm::errs() << "Module before optimizations \n";
module.print(llvm::errs(), nullptr);
});
return compiled_module_info;
b.CreateRetVoid();
}
CHJIT::CompiledModuleInfo compileAggregateFunctonsV2(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name)
static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto compiled_module_info = jit.compileModule([&](llvm::Module & module)
auto & context = module.getContext();
llvm::IRBuilder<> b(context);
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * get_place_func_declaration = llvm::FunctionType::get(b.getInt8Ty()->getPointerTo(), { b.getInt8Ty()->getPointerTo(), size_type }, /*isVarArg=*/false);
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), get_place_func_declaration->getPointerTo(), b.getInt8Ty()->getPointerTo() }, false);
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
auto * arguments = aggregate_loop_func_definition->args().begin();
llvm::Value * rows_count_arg = arguments++;
llvm::Value * columns_arg = arguments++;
llvm::Value * get_place_function_arg = arguments++;
llvm::Value * get_place_function_context_arg = arguments++;
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (size_t i = 0; i < functions.size(); ++i)
{
auto & context = module.getContext();
llvm::IRBuilder<> b (context);
auto argument_type = functions[i].function->getArgumentTypes()[0];
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
}
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
/// Initialize loop
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo()->getPointerTo();
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), aggregate_data_places_type }, false);
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
auto * aggregate_loop_func_definition = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, result_name, module);
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
auto * arguments = aggregate_loop_func_definition->args().begin();
llvm::Value * rows_count_arg = &*arguments++;
llvm::Value * columns_arg = &*arguments++;
llvm::Value * aggregate_data_places_arg = &*arguments++;
b.SetInsertPoint(loop);
/// Initialize ColumnDataPlaceholder llvm representation of ColumnData
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
b.SetInsertPoint(entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
}
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (size_t i = 0; i < functions.size(); ++i)
{
auto argument_type = functions[i].function->getArgumentTypes()[0];
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
}
auto * aggregation_place = b.CreateCall(get_place_func_declaration, get_place_function_arg, { get_place_function_context_arg, counter_phi });
/// Initialize loop
for (size_t i = 0; i < functions.size(); ++i)
{
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
const auto * aggregate_function_ptr = functions[i].function;
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func_definition);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func_definition);
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregation_place, aggregate_function_offset);
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
auto column_type = functions[i].function->getArgumentTypes()[0];
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, column_type, column_data);
}
b.SetInsertPoint(loop);
/// End of loop
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
auto * aggregate_data_place_phi = b.CreatePHI(aggregate_data_places_type, 2);
aggregate_data_place_phi->addIncoming(aggregate_data_places_arg, entry);
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
counter_phi->addIncoming(value, loop);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
}
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
for (size_t i = 0; i < functions.size(); ++i)
{
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
const auto * aggregate_function_ptr = functions[i].function;
b.SetInsertPoint(end);
b.CreateRetVoid();
}
auto * aggregate_data_place = b.CreateLoad(b.getInt8Ty()->getPointerTo(), aggregate_data_place_phi);
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place, aggregate_function_offset);
static void compileMergeAggregatesStates(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto & context = module.getContext();
llvm::IRBuilder<> b(context);
auto column_type = functions[i].function->getArgumentTypes()[0];
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
aggregate_function_ptr->compile(b, aggregation_place_with_offset, column_type, column_data);
}
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo();
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { aggregate_data_places_type, aggregate_data_places_type }, false);
auto * aggregate_loop_func = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
/// End of loop
auto * arguments = aggregate_loop_func->args().begin();
llvm::Value * aggregate_data_place_dst_arg = arguments++;
llvm::Value * aggregate_data_place_src_arg = arguments++;
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func);
b.SetInsertPoint(entry);
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1), "", true, true);
counter_phi->addIncoming(value, loop);
for (const auto & function_to_compile : functions)
{
size_t aggregate_function_offset = function_to_compile.aggregate_data_offset;
const auto * aggregate_function_ptr = function_to_compile.function;
aggregate_data_place_phi->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_phi, 1), loop);
auto * aggregate_data_place_merge_dst_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_dst_arg, aggregate_function_offset);
auto * aggregate_data_place_merge_src_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_src_arg, aggregate_function_offset);
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
aggregate_function_ptr->compileMerge(b, aggregate_data_place_merge_dst_with_offset, aggregate_data_place_merge_src_with_offset);
}
b.SetInsertPoint(end);
b.CreateRetVoid();
b.CreateRetVoid();
}
llvm::errs() << "Module before optimizations \n";
module.print(llvm::errs(), nullptr);
static void compileInsertAggregatesIntoResultColumns(llvm::Module & module, const std::vector<AggregateFunctionWithOffset> & functions, const std::string & name)
{
auto & context = module.getContext();
llvm::IRBuilder<> b(context);
auto * size_type = b.getIntNTy(sizeof(size_t) * 8);
auto * column_data_type = llvm::StructType::get(b.getInt8PtrTy(), b.getInt8PtrTy());
auto * aggregate_data_places_type = b.getInt8Ty()->getPointerTo()->getPointerTo();
auto * aggregate_loop_func_declaration = llvm::FunctionType::get(b.getVoidTy(), { size_type, column_data_type->getPointerTo(), aggregate_data_places_type }, false);
auto * aggregate_loop_func = llvm::Function::Create(aggregate_loop_func_declaration, llvm::Function::ExternalLinkage, name, module);
auto * arguments = aggregate_loop_func->args().begin();
llvm::Value * rows_count_arg = &*arguments++;
llvm::Value * columns_arg = &*arguments++;
llvm::Value * aggregate_data_places_arg = &*arguments++;
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(functions.size());
for (size_t i = 0; i < functions.size(); ++i)
{
auto return_type = functions[i].function->getReturnType();
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(return_type))->getPointerTo());
}
auto * end = llvm::BasicBlock::Create(b.getContext(), "end", aggregate_loop_func);
auto * loop = llvm::BasicBlock::Create(b.getContext(), "loop", aggregate_loop_func);
b.CreateCondBr(b.CreateICmpEQ(rows_count_arg, llvm::ConstantInt::get(size_type, 0)), end, loop);
b.SetInsertPoint(loop);
auto * counter_phi = b.CreatePHI(rows_count_arg->getType(), 2);
counter_phi->addIncoming(llvm::ConstantInt::get(size_type, 0), entry);
auto * aggregate_data_place_phi = b.CreatePHI(aggregate_data_places_type, 2);
aggregate_data_place_phi->addIncoming(aggregate_data_places_arg, entry);
for (auto & col : columns)
{
col.data = b.CreatePHI(col.data_init->getType(), 2);
col.data->addIncoming(col.data_init, entry);
}
for (size_t i = 0; i < functions.size(); ++i)
{
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
const auto * aggregate_function_ptr = functions[i].function;
auto * aggregate_data_place = b.CreateLoad(b.getInt8Ty()->getPointerTo(), aggregate_data_place_phi);
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place, aggregate_function_offset);
auto column_type = functions[i].function->getArgumentTypes()[0];
auto * final_value = aggregate_function_ptr->compileGetResult(b, aggregation_place_with_offset);
b.CreateStore(final_value, columns[i].data);
}
/// End of loop
auto * cur_block = b.GetInsertBlock();
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1), "", true, true);
counter_phi->addIncoming(value, loop);
aggregate_data_place_phi->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, aggregate_data_place_phi, 1), loop);
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
b.SetInsertPoint(end);
b.CreateRetVoid();
}
CompiledAggregateFunctions compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionWithOffset> & functions, std::string functions_dump_name)
{
std::string create_aggregate_states_functions_name = functions_dump_name + "_create";
std::string add_aggregate_states_functions_name = functions_dump_name + "_add";
std::string merge_aggregate_states_functions_name = functions_dump_name + "_merge";
std::string insert_aggregate_states_functions_name = functions_dump_name + "_insert";
auto compiled_module = jit.compileModule([&](llvm::Module & module)
{
compileCreateAggregateStatesFunctions(module, functions, create_aggregate_states_functions_name);
compileAddIntoAggregateStatesFunctions(module, functions, add_aggregate_states_functions_name);
compileMergeAggregatesStates(module, functions, merge_aggregate_states_functions_name);
compileInsertAggregatesIntoResultColumns(module, functions, insert_aggregate_states_functions_name);
});
return compiled_module_info;
auto create_aggregate_states_function = reinterpret_cast<JITCreateAggregateStatesFunction>(compiled_module.function_name_to_symbol[create_aggregate_states_functions_name]);
auto add_into_aggregate_states_function = reinterpret_cast<JITAddIntoAggregateStatesFunction>(compiled_module.function_name_to_symbol[add_aggregate_states_functions_name]);
auto merge_aggregate_states_function = reinterpret_cast<JITMergeAggregateStatesFunction>(compiled_module.function_name_to_symbol[merge_aggregate_states_functions_name]);
auto insert_aggregate_states_function = reinterpret_cast<JITInsertAggregatesIntoColumnsFunction>(compiled_module.function_name_to_symbol[insert_aggregate_states_functions_name]);
assert(create_aggregate_states_function);
assert(add_into_aggregate_states_function);
assert(merge_aggregate_states_function);
assert(insert_aggregate_states_function);
CompiledAggregateFunctions compiled_aggregate_functions
{
.create_aggregate_states_function = create_aggregate_states_function,
.add_into_aggregate_states_function = add_into_aggregate_states_function,
.merge_aggregate_states_function = merge_aggregate_states_function,
.insert_aggregates_into_columns_function = insert_aggregate_states_function
};
return compiled_aggregate_functions;
}
}

View File

@ -32,6 +32,14 @@ using ColumnDataRowsSize = size_t;
using JITCompiledFunction = void (*)(ColumnDataRowsSize, ColumnData *);
struct CompiledFunction
{
JITCompiledFunction compiled_function;
CHJIT::CompiledModule compiled_module;
};
/** Compile function to native jit code using CHJIT instance.
* Function is compiled as single module.
* After this function execution, code for function will be compiled and can be queried using
@ -41,22 +49,33 @@ using JITCompiledFunction = void (*)(ColumnDataRowsSize, ColumnData *);
* It is important that ColumnData parameter of JITCompiledFunction is result column,
* and will be filled by compiled function.
*/
CHJIT::CompiledModuleInfo compileFunction(CHJIT & jit, const IFunctionBase & function);
CompiledFunction compileFunction(CHJIT & jit, const IFunctionBase & function);
using GetAggregateDataContext = char *;
using GetAggregateDataFunction = AggregateDataPtr (*)(GetAggregateDataContext, size_t);
using JITCompiledAggregateFunction = void (*)(ColumnDataRowsSize, ColumnData *, GetAggregateDataFunction, GetAggregateDataContext);
struct AggregateFunctionToCompile
struct AggregateFunctionWithOffset
{
const IAggregateFunction * function;
size_t aggregate_data_offset;
};
CHJIT::CompiledModuleInfo compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name);
using GetAggregateDataContext = char *;
using GetAggregateDataFunction = AggregateDataPtr (*)(GetAggregateDataContext, size_t);
using JITCompiledAggregateFunctionV2 = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr *);
CHJIT::CompiledModuleInfo compileAggregateFunctonsV2(CHJIT & jit, const std::vector<AggregateFunctionToCompile> & functions, const std::string & result_name);
using JITCreateAggregateStatesFunction = void (*)(AggregateDataPtr);
using JITAddIntoAggregateStatesFunction = void (*)(ColumnDataRowsSize, ColumnData *, GetAggregateDataFunction, GetAggregateDataContext);
using JITMergeAggregateStatesFunction = void (*)(AggregateDataPtr, AggregateDataPtr);
using JITInsertAggregatesIntoColumnsFunction = void (*)(ColumnDataRowsSize, ColumnData *, AggregateDataPtr *);
struct CompiledAggregateFunctions
{
JITCreateAggregateStatesFunction create_aggregate_states_function;
JITAddIntoAggregateStatesFunction add_into_aggregate_states_function;
JITMergeAggregateStatesFunction merge_aggregate_states_function;
JITInsertAggregatesIntoColumnsFunction insert_aggregates_into_columns_function;
CHJIT::CompiledModule compiled_module;
};
CompiledAggregateFunctions compileAggregateFunctons(CHJIT & jit, const std::vector<AggregateFunctionWithOffset> & functions, std::string functions_dump_name);
}

View File

@ -18,7 +18,7 @@ int main(int argc, char **argv)
jit.registerExternalSymbol("test_function", reinterpret_cast<void *>(&test_function));
auto compiled_module_info = jit.compileModule([](llvm::Module & module)
auto compiled_module = jit.compileModule([](llvm::Module & module)
{
auto & context = module.getContext();
llvm::IRBuilder<> b (context);
@ -43,13 +43,14 @@ int main(int argc, char **argv)
b.CreateRet(value);
});
for (const auto & compiled_function_name : compiled_module_info.compiled_functions)
for (const auto & [compiled_function_name, _] : compiled_module.function_name_to_symbol)
{
std::cerr << compiled_function_name << std::endl;
}
int64_t value = 5;
auto * test_name_function = reinterpret_cast<int64_t (*)(int64_t *)>(jit.findCompiledFunction(compiled_module_info, "test_name"));
auto * symbol = compiled_module.function_name_to_symbol["test_name"];
auto * test_name_function = reinterpret_cast<int64_t (*)(int64_t *)>(symbol);
auto result = test_name_function(&value);
std::cerr << "Result " << result << std::endl;