Compile AggregateFunctionIf

This commit is contained in:
Maksim Kita 2021-06-04 13:43:11 +03:00
parent 9b71b1040a
commit a5ef0067b8
7 changed files with 235 additions and 22 deletions

View File

@ -5,6 +5,14 @@
#include <Common/assert_cast.h>
#include <AggregateFunctions/IAggregateFunction.h>
#if !defined(ARCADIA_BUILD)
# include <Common/config.h>
#endif
#if USE_EMBEDDED_COMPILER
# include <llvm/IR/IRBuilder.h>
# include <DataTypes/Native.h>
#endif
namespace DB
{
@ -154,6 +162,76 @@ public:
const Array & params, const AggregateFunctionProperties & properties) const override;
AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
return nested_func->isCompilable();
}
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
nested_func->compileCreate(builder, aggregate_data_ptr);
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
const auto & predicate_type = arguments_types[argument_values.size() - 1];
auto * predicate_value = argument_values[argument_values.size() - 1];
auto * head = b.GetInsertBlock();
auto * join_block = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent());
auto * if_true = llvm::BasicBlock::Create(head->getContext(), "if_true", head->getParent());
auto * if_false = llvm::BasicBlock::Create(head->getContext(), "if_false", head->getParent());
auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value);
b.CreateCondBr(is_predicate_true, if_true, if_false);
b.SetInsertPoint(if_true);
size_t arguments_size = arguments_types.size();
DataTypes argument_types_without_predicate;
std::vector<llvm::Value *> argument_values_without_predicate;
argument_types_without_predicate.resize(arguments_size - 1);
argument_values_without_predicate.resize(arguments_size - 1);
for (size_t i = 0; i < arguments_types.size() - 1; ++i)
{
argument_types_without_predicate[i] = arguments_types[i];
argument_values_without_predicate[i] = argument_values[i];
}
nested_func->compileAdd(builder, aggregate_data_ptr, argument_types_without_predicate, argument_values_without_predicate);
b.CreateBr(join_block);
b.SetInsertPoint(if_false);
b.CreateBr(join_block);
b.SetInsertPoint(join_block);
}
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
nested_func->compileMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
return nested_func->compileGetResult(builder, aggregate_data_ptr);
}
#endif
};
}

View File

@ -393,7 +393,7 @@ public:
column.getData().push_back(this->data(place).get());
}
#if USE_EMBEDDED_COMPILER
#if USE_EMBEDDED_COMPILER
bool isCompilable() const override
{
@ -415,7 +415,7 @@ public:
b.CreateStore(llvm::ConstantInt::get(return_type, 0), aggregate_sum_ptr);
}
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypePtr & value_type, llvm::Value * value) const override
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
@ -424,7 +424,10 @@ public:
auto * sum_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
auto * sum_value = b.CreateLoad(return_type, sum_value_ptr);
auto * value_cast_to_result = nativeCast(b, value_type, value, return_type);
const auto & argument_type = arguments_types[0];
const auto & argument_value = argument_values[0];
auto * value_cast_to_result = nativeCast(b, argument_type, argument_value, return_type);
auto * sum_result_value = sum_value->getType()->isIntegerTy() ? b.CreateAdd(sum_value, value_cast_to_result) : b.CreateFAdd(sum_value, value_cast_to_result);
b.CreateStore(sum_result_value, sum_value_ptr);
@ -456,7 +459,7 @@ public:
return b.CreateLoad(return_type, sum_value_ptr);
}
#endif
#endif
private:
UInt32 scale;

View File

@ -260,7 +260,7 @@ public:
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}
virtual void compileAdd(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/, const DataTypePtr & /*value_type*/, llvm::Value * /*value*/) const
virtual void compileAdd(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/, const DataTypes & /*arguments_types*/, const std::vector<llvm::Value *> & /*arguments_values*/) const
{
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
}

View File

@ -214,6 +214,8 @@ void Aggregator::Params::explain(JSONBuilder::JSONMap & map) const
}
}
#if USE_EMBEDDED_COMPILER
static CHJIT & getJITInstance()
{
static CHJIT jit;
@ -246,6 +248,8 @@ static std::string dumpAggregateFunction(const IAggregateFunction * function)
return function_dump;
}
#endif
Aggregator::Aggregator(const Params & params_)
: params(params_)
{
@ -297,13 +301,18 @@ Aggregator::Aggregator(const Params & params_)
HashMethodContext::Settings cache_settings;
cache_settings.max_threads = params.max_threads;
aggregation_state_cache = AggregatedDataVariants::createCache(method_chosen, cache_settings);
#if USE_EMBEDDED_COMPILER
compileAggregateFunctions();
#endif
}
#if USE_EMBEDDED_COMPILER
void Aggregator::compileAggregateFunctions()
{
if (!params.compile_aggregate_expressions ||
params.overflow_row)
if (!params.compile_aggregate_expressions || params.overflow_row)
return;
std::vector<AggregateFunctionWithOffset> functions_to_compile;
@ -334,7 +343,7 @@ void Aggregator::compileAggregateFunctions()
++aggregate_instructions_size;
}
if (functions_to_compile.size() != aggregate_instructions_size)
if (functions_to_compile.empty() || functions_to_compile.size() != aggregate_instructions_size)
return;
CompiledAggregateFunctions compiled_aggregate_functions;
@ -362,6 +371,8 @@ void Aggregator::compileAggregateFunctions()
compiled_functions.emplace(std::move(compiled_aggregate_functions));
}
#endif
AggregatedDataVariants::Type Aggregator::chooseAggregationMethod()
{
/// If no keys. All aggregating to single row.
@ -574,6 +585,8 @@ void NO_INLINE Aggregator::executeImpl(
executeImplBatch<true>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
}
#if USE_EMBEDDED_COMPILER
template <bool no_more_keys, typename Method>
void NO_INLINE Aggregator::handleAggregationJIT(
Method & method,
@ -587,7 +600,11 @@ void NO_INLINE Aggregator::handleAggregationJIT(
/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
columns_data.emplace_back(getColumnData(inst->batch_arguments[0]));
{
size_t arguments_size = inst->that->getArgumentTypes().size();
for (size_t i = 0; i < arguments_size; ++i)
columns_data.emplace_back(getColumnData(inst->batch_arguments[i]));
}
auto add_into_aggregate_states_function = compiled_functions->add_into_aggregate_states_function;
auto create_aggregate_states_function = compiled_functions->create_aggregate_states_function;
@ -635,6 +652,8 @@ void NO_INLINE Aggregator::handleAggregationJIT(
add_into_aggregate_states_function(rows, columns_data.data(), get_aggregate_data_function, get_aggregate_data_context);
}
#endif
template <bool no_more_keys, typename Method>
void NO_INLINE Aggregator::handleAggregationDefault(
Method & method,
@ -751,10 +770,16 @@ void NO_INLINE Aggregator::executeImplBatch(
}
}
#if USE_EMBEDDED_COMPILER
if (compiled_functions)
{
handleAggregationJIT<no_more_keys>(method, state, aggregates_pool, rows, aggregate_instructions);
}
else
#endif
{
handleAggregationDefault<no_more_keys>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
}
}
@ -857,6 +882,39 @@ bool Aggregator::executeOnBlock(const Block & block, AggregatedDataVariants & re
bool Aggregator::executeOnBlock(Columns columns, UInt64 num_rows, AggregatedDataVariants & result,
ColumnRawPtrs & key_columns, AggregateColumns & aggregate_columns, bool & no_more_keys)
{
// std::cerr << "Aggregator::executeOnBlock" << std::endl;
// std::cerr << "Columns " << columns.size() << std::endl;
// for (const auto & column : columns)
// {
// if (column)
// std::cerr << column->dumpStructure() << "\n";
// }
// std::cerr << "Num rows " << num_rows << std::endl;
// std::cerr << "Key columns before " << key_columns.size() << std::endl;
// for (const auto & column : key_columns)
// {
// if (column)
// std::cerr << column->dumpStructure() << "\n";
// }
// std::cerr << "Aggregate columns before " << aggregate_columns.size() << std::endl;
// for (size_t i = 0; i < aggregate_columns.size(); ++i)
// {
// const auto & aggregate_function_columns = aggregate_columns[i];
// for (const auto & aggregate_function_column : aggregate_function_columns)
// {
// if (aggregate_function_column)
// {
// std::cerr << "Aggregate function column " << static_cast<const void *>(aggregate_function_column) << std::endl;
// std::cerr << aggregate_function_column->dumpStructure() << "\n";
// }
// }
// }
// std::cerr << "No more keys " << no_more_keys << std::endl;
/// `result` will destroy the states of aggregate functions in the destructor
result.aggregator = this;
@ -890,6 +948,7 @@ bool Aggregator::executeOnBlock(Columns columns, UInt64 num_rows, AggregatedData
}
}
}
NestedColumnsHolder nested_columns_holder;
AggregateFunctionInstructions aggregate_functions_instructions;
prepareAggregateInstructions(columns, aggregate_columns, materialized_columns, aggregate_functions_instructions, nested_columns_holder);
@ -901,6 +960,28 @@ bool Aggregator::executeOnBlock(Columns columns, UInt64 num_rows, AggregatedData
result.without_key = place;
}
// std::cerr << "Key columns after " << key_columns.size() << std::endl;
// for (const auto & column : key_columns)
// {
// if (column)
// std::cerr << column->dumpStructure() << "\n";
// }
// std::cerr << "Aggregate columns after " << aggregate_columns.size() << std::endl;
// for (size_t i = 0; i < aggregate_columns.size(); ++i)
// {
// const auto & aggregate_function_columns = aggregate_columns[i];
// for (const auto & aggregate_function_column : aggregate_function_columns)
// {
// if (aggregate_function_column)
// {
// std::cerr << "Aggregate function column " << static_cast<const void *>(aggregate_function_column) << std::endl;
// std::cerr << aggregate_function_column->dumpStructure() << "\n";
// }
// }
// }
/// We select one of the aggregation methods and call it.
/// For the case when there are no keys (all aggregate into one row).
@ -1291,6 +1372,7 @@ void NO_INLINE Aggregator::convertToBlockImplFinal(
auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns, key_sizes);
const auto & key_sizes_ref = shuffled_key_sizes ? *shuffled_key_sizes : key_sizes;
#if USE_EMBEDDED_COMPILER
if (compiled_functions)
{
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[data.size()]);
@ -1316,6 +1398,7 @@ void NO_INLINE Aggregator::convertToBlockImplFinal(
insert_aggregate_states_function(data.size(), columns_data.data(), places.get());
}
else
#endif
{
data.forEachValue([&](const auto & key, auto & mapped)
{
@ -1751,6 +1834,7 @@ void NO_INLINE Aggregator::mergeDataImpl(
if constexpr (Method::low_cardinality_optimization)
mergeDataNullKey<Method, Table>(table_dst, table_src, arena);
#if USE_EMBEDDED_COMPILER
if (compiled_functions)
{
auto merge_aggregate_states_function_typed = compiled_functions->merge_aggregate_states_function;
@ -1770,6 +1854,7 @@ void NO_INLINE Aggregator::mergeDataImpl(
});
}
else
#endif
{
table_src.mergeToViaEmplace(table_dst, [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
{

View File

@ -1083,7 +1083,9 @@ private:
/// For external aggregation.
TemporaryFiles temporary_files;
#if USE_EMBEDDED_COMPILER
std::optional<CompiledAggregateFunctions> compiled_functions;
#endif
/** Try to compile aggregate functions.
*/

View File

@ -299,8 +299,6 @@ static void compileCreateAggregateStatesFunctions(llvm::Module & module, const s
aggregate_function->compileCreate(b, aggregation_place_with_offset);
}
module.print(llvm::errs(), nullptr);
b.CreateRetVoid();
}
@ -328,12 +326,29 @@ static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
b.SetInsertPoint(entry);
std::vector<ColumnDataPlaceholder> columns(functions.size());
std::vector<ColumnDataPlaceholder> columns;
size_t previous_columns_size = 0;
for (size_t i = 0; i < functions.size(); ++i)
{
auto argument_type = functions[i].function->getArgumentTypes()[0];
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
auto argument_types = functions[i].function->getArgumentTypes();
ColumnDataPlaceholder data_placeholder;
std::cerr << "Function " << functions[i].function->getName() << std::endl;
size_t function_arguments_size = argument_types.size();
for (size_t column_argument_index = 0; column_argument_index < function_arguments_size; ++column_argument_index)
{
const auto & argument_type = argument_types[previous_columns_size + column_argument_index];
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, column_argument_index));
std::cerr << "Argument type " << argument_type->getName() << std::endl;
data_placeholder.data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
columns.emplace_back(data_placeholder);
}
previous_columns_size += function_arguments_size;
}
/// Initialize loop
@ -356,16 +371,28 @@ static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const
auto * aggregation_place = b.CreateCall(get_place_func_declaration, get_place_function_arg, { get_place_function_context_arg, counter_phi });
for (size_t i = 0; i < functions.size(); ++i)
previous_columns_size = 0;
for (const auto & function : functions)
{
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
const auto * aggregate_function_ptr = functions[i].function;
size_t aggregate_function_offset = function.aggregate_data_offset;
const auto * aggregate_function_ptr = function.function;
auto arguments_types = function.function->getArgumentTypes();
std::vector<llvm::Value *> arguments_values;
size_t function_arguments_size = arguments_types.size();
arguments_values.resize(function_arguments_size);
for (size_t column_argument_index = 0; column_argument_index < function_arguments_size; ++column_argument_index)
{
auto * column_argument_data = columns[previous_columns_size + column_argument_index].data;
arguments_values[column_argument_index] = b.CreateLoad(toNativeType(b, arguments_types[column_argument_index]), column_argument_data);
}
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregation_place, aggregate_function_offset);
aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, arguments_types, arguments_values);
auto column_type = functions[i].function->getArgumentTypes()[0];
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, column_type, column_data);
previous_columns_size += function_arguments_size;
}
/// End of loop
@ -374,12 +401,13 @@ static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const
for (auto & col : columns)
{
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
if (col.null)
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
}
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
counter_phi->addIncoming(value, loop);
counter_phi->addIncoming(value, cur_block);
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);

View File

@ -1,5 +1,11 @@
#include <iostream>
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#if USE_EMBEDDED_COMPILER
#include <llvm/IR/IRBuilder.h>
#include <Interpreters/JIT/CHJIT.h>
@ -56,3 +62,14 @@ int main(int argc, char **argv)
return 0;
}
#else
int main(int argc, char **argv)
{
(void)(argc);
(void)(argv);
return 0;
}
#endif