mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-17 13:13:36 +00:00
Compile AggregateFunctionIf
This commit is contained in:
parent
9b71b1040a
commit
a5ef0067b8
@ -5,6 +5,14 @@
|
||||
#include <Common/assert_cast.h>
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#if !defined(ARCADIA_BUILD)
|
||||
# include <Common/config.h>
|
||||
#endif
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
# include <llvm/IR/IRBuilder.h>
|
||||
# include <DataTypes/Native.h>
|
||||
#endif
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -154,6 +162,76 @@ public:
|
||||
const Array & params, const AggregateFunctionProperties & properties) const override;
|
||||
|
||||
AggregateFunctionPtr getNestedFunction() const override { return nested_func; }
|
||||
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
bool isCompilable() const override
|
||||
{
|
||||
return nested_func->isCompilable();
|
||||
}
|
||||
|
||||
/// Emits the state-creation code by forwarding to the nested function.
/// `aggregate_data_ptr` is handed over unchanged.
void compileCreate(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
    const auto & wrapped_function = *nested_func;
    wrapped_function.compileCreate(builder, aggregate_data_ptr);
}
|
||||
|
||||
/// Emits code that adds the current row to the nested function only when the predicate is true.
///
/// The predicate is read from the last slot of `arguments_types` / `argument_values`;
/// all preceding arguments are forwarded to the nested function as is.
///
/// Generated control flow:
///     <current block> -- predicate true  --> if_true  (nested add) --> join_block
///                     -- predicate false --> if_false (nothing)    --> join_block
///
/// When this method returns, the builder's insertion point is `join_block`.
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
{
    auto & ir = static_cast<llvm::IRBuilder<> &>(builder);

    /// The predicate position is derived from the number of values and used for both vectors.
    const size_t predicate_index = argument_values.size() - 1;
    const auto & predicate_type = arguments_types[predicate_index];
    auto * predicate_value = argument_values[predicate_index];

    /// Everything except the trailing predicate goes to the nested function.
    const size_t nested_arguments_size = arguments_types.size() - 1;

    DataTypes nested_argument_types(nested_arguments_size);
    std::vector<llvm::Value *> nested_argument_values(nested_arguments_size);

    for (size_t position = 0; position != nested_arguments_size; ++position)
    {
        nested_argument_types[position] = arguments_types[position];
        nested_argument_values[position] = argument_values[position];
    }

    auto * current_block = ir.GetInsertBlock();
    auto & llvm_context = current_block->getContext();
    auto * parent_function = current_block->getParent();

    /// Blocks are created in the order: join, true branch, false branch.
    auto * join_block = llvm::BasicBlock::Create(llvm_context, "join_block", parent_function);
    auto * if_true = llvm::BasicBlock::Create(llvm_context, "if_true", parent_function);
    auto * if_false = llvm::BasicBlock::Create(llvm_context, "if_false", parent_function);

    /// The boolean cast is emitted into the current block, right before the conditional branch.
    auto * is_predicate_true = nativeBoolCast(ir, predicate_type, predicate_value);
    ir.CreateCondBr(is_predicate_true, if_true, if_false);

    /// True branch: let the nested function consume the row, then continue at the join point.
    ir.SetInsertPoint(if_true);
    nested_func->compileAdd(builder, aggregate_data_ptr, nested_argument_types, nested_argument_values);
    ir.CreateBr(join_block);

    /// False branch: skip the row.
    ir.SetInsertPoint(if_false);
    ir.CreateBr(join_block);

    /// Subsequent code emitted by the caller goes after the conditional part.
    ir.SetInsertPoint(join_block);
}
|
||||
|
||||
/// Emits the merge of two states by forwarding to the nested function.
/// Both the destination and the source state pointers are handed over unchanged.
void compileMerge(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_dst_ptr, llvm::Value * aggregate_data_src_ptr) const override
{
    const auto & wrapped_function = *nested_func;
    wrapped_function.compileMerge(builder, aggregate_data_dst_ptr, aggregate_data_src_ptr);
}
|
||||
|
||||
/// Emits the code that produces the final value by forwarding to the nested function
/// and returns whatever value the nested function produced.
llvm::Value * compileGetResult(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr) const override
{
    llvm::Value * nested_result = nested_func->compileGetResult(builder, aggregate_data_ptr);
    return nested_result;
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -393,7 +393,7 @@ public:
|
||||
column.getData().push_back(this->data(place).get());
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
bool isCompilable() const override
|
||||
{
|
||||
@ -415,7 +415,7 @@ public:
|
||||
b.CreateStore(llvm::ConstantInt::get(return_type, 0), aggregate_sum_ptr);
|
||||
}
|
||||
|
||||
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypePtr & value_type, llvm::Value * value) const override
|
||||
void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector<llvm::Value *> & argument_values) const override
|
||||
{
|
||||
llvm::IRBuilder<> & b = static_cast<llvm::IRBuilder<> &>(builder);
|
||||
|
||||
@ -424,7 +424,10 @@ public:
|
||||
auto * sum_value_ptr = b.CreatePointerCast(aggregate_data_ptr, return_type->getPointerTo());
|
||||
auto * sum_value = b.CreateLoad(return_type, sum_value_ptr);
|
||||
|
||||
auto * value_cast_to_result = nativeCast(b, value_type, value, return_type);
|
||||
const auto & argument_type = arguments_types[0];
|
||||
const auto & argument_value = argument_values[0];
|
||||
|
||||
auto * value_cast_to_result = nativeCast(b, argument_type, argument_value, return_type);
|
||||
auto * sum_result_value = sum_value->getType()->isIntegerTy() ? b.CreateAdd(sum_value, value_cast_to_result) : b.CreateFAdd(sum_value, value_cast_to_result);
|
||||
|
||||
b.CreateStore(sum_result_value, sum_value_ptr);
|
||||
@ -456,7 +459,7 @@ public:
|
||||
return b.CreateLoad(return_type, sum_value_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
private:
|
||||
UInt32 scale;
|
||||
|
@ -260,7 +260,7 @@ public:
|
||||
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
virtual void compileAdd(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/, const DataTypePtr & /*value_type*/, llvm::Value * /*value*/) const
|
||||
virtual void compileAdd(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/, const DataTypes & /*arguments_types*/, const std::vector<llvm::Value *> & /*arguments_values*/) const
|
||||
{
|
||||
throw Exception(getName() + " is not JIT-compilable", ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
@ -214,6 +214,8 @@ void Aggregator::Params::explain(JSONBuilder::JSONMap & map) const
|
||||
}
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
static CHJIT & getJITInstance()
|
||||
{
|
||||
static CHJIT jit;
|
||||
@ -246,6 +248,8 @@ static std::string dumpAggregateFunction(const IAggregateFunction * function)
|
||||
return function_dump;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Aggregator::Aggregator(const Params & params_)
|
||||
: params(params_)
|
||||
{
|
||||
@ -297,13 +301,18 @@ Aggregator::Aggregator(const Params & params_)
|
||||
HashMethodContext::Settings cache_settings;
|
||||
cache_settings.max_threads = params.max_threads;
|
||||
aggregation_state_cache = AggregatedDataVariants::createCache(method_chosen, cache_settings);
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
compileAggregateFunctions();
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
void Aggregator::compileAggregateFunctions()
|
||||
{
|
||||
if (!params.compile_aggregate_expressions ||
|
||||
params.overflow_row)
|
||||
if (!params.compile_aggregate_expressions || params.overflow_row)
|
||||
return;
|
||||
|
||||
std::vector<AggregateFunctionWithOffset> functions_to_compile;
|
||||
@ -334,7 +343,7 @@ void Aggregator::compileAggregateFunctions()
|
||||
++aggregate_instructions_size;
|
||||
}
|
||||
|
||||
if (functions_to_compile.size() != aggregate_instructions_size)
|
||||
if (functions_to_compile.empty() || functions_to_compile.size() != aggregate_instructions_size)
|
||||
return;
|
||||
|
||||
CompiledAggregateFunctions compiled_aggregate_functions;
|
||||
@ -362,6 +371,8 @@ void Aggregator::compileAggregateFunctions()
|
||||
compiled_functions.emplace(std::move(compiled_aggregate_functions));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
AggregatedDataVariants::Type Aggregator::chooseAggregationMethod()
|
||||
{
|
||||
/// If no keys. All aggregating to single row.
|
||||
@ -574,6 +585,8 @@ void NO_INLINE Aggregator::executeImpl(
|
||||
executeImplBatch<true>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
template <bool no_more_keys, typename Method>
|
||||
void NO_INLINE Aggregator::handleAggregationJIT(
|
||||
Method & method,
|
||||
@ -587,7 +600,11 @@ void NO_INLINE Aggregator::handleAggregationJIT(
|
||||
|
||||
/// Add values to the aggregate functions.
|
||||
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
|
||||
columns_data.emplace_back(getColumnData(inst->batch_arguments[0]));
|
||||
{
|
||||
size_t arguments_size = inst->that->getArgumentTypes().size();
|
||||
for (size_t i = 0; i < arguments_size; ++i)
|
||||
columns_data.emplace_back(getColumnData(inst->batch_arguments[i]));
|
||||
}
|
||||
|
||||
auto add_into_aggregate_states_function = compiled_functions->add_into_aggregate_states_function;
|
||||
auto create_aggregate_states_function = compiled_functions->create_aggregate_states_function;
|
||||
@ -635,6 +652,8 @@ void NO_INLINE Aggregator::handleAggregationJIT(
|
||||
add_into_aggregate_states_function(rows, columns_data.data(), get_aggregate_data_function, get_aggregate_data_context);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <bool no_more_keys, typename Method>
|
||||
void NO_INLINE Aggregator::handleAggregationDefault(
|
||||
Method & method,
|
||||
@ -751,10 +770,16 @@ void NO_INLINE Aggregator::executeImplBatch(
|
||||
}
|
||||
}
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
if (compiled_functions)
|
||||
{
|
||||
handleAggregationJIT<no_more_keys>(method, state, aggregates_pool, rows, aggregate_instructions);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
handleAggregationDefault<no_more_keys>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -857,6 +882,39 @@ bool Aggregator::executeOnBlock(const Block & block, AggregatedDataVariants & re
|
||||
bool Aggregator::executeOnBlock(Columns columns, UInt64 num_rows, AggregatedDataVariants & result,
|
||||
ColumnRawPtrs & key_columns, AggregateColumns & aggregate_columns, bool & no_more_keys)
|
||||
{
|
||||
// std::cerr << "Aggregator::executeOnBlock" << std::endl;
|
||||
// std::cerr << "Columns " << columns.size() << std::endl;
|
||||
|
||||
// for (const auto & column : columns)
|
||||
// {
|
||||
// if (column)
|
||||
// std::cerr << column->dumpStructure() << "\n";
|
||||
// }
|
||||
|
||||
// std::cerr << "Num rows " << num_rows << std::endl;
|
||||
// std::cerr << "Key columns before " << key_columns.size() << std::endl;
|
||||
// for (const auto & column : key_columns)
|
||||
// {
|
||||
// if (column)
|
||||
// std::cerr << column->dumpStructure() << "\n";
|
||||
// }
|
||||
|
||||
// std::cerr << "Aggregate columns before " << aggregate_columns.size() << std::endl;
|
||||
// for (size_t i = 0; i < aggregate_columns.size(); ++i)
|
||||
// {
|
||||
// const auto & aggregate_function_columns = aggregate_columns[i];
|
||||
|
||||
// for (const auto & aggregate_function_column : aggregate_function_columns)
|
||||
// {
|
||||
// if (aggregate_function_column)
|
||||
// {
|
||||
// std::cerr << "Aggregate function column " << static_cast<const void *>(aggregate_function_column) << std::endl;
|
||||
// std::cerr << aggregate_function_column->dumpStructure() << "\n";
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// std::cerr << "No more keys " << no_more_keys << std::endl;
|
||||
|
||||
/// `result` will destroy the states of aggregate functions in the destructor
|
||||
result.aggregator = this;
|
||||
|
||||
@ -890,6 +948,7 @@ bool Aggregator::executeOnBlock(Columns columns, UInt64 num_rows, AggregatedData
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
NestedColumnsHolder nested_columns_holder;
|
||||
AggregateFunctionInstructions aggregate_functions_instructions;
|
||||
prepareAggregateInstructions(columns, aggregate_columns, materialized_columns, aggregate_functions_instructions, nested_columns_holder);
|
||||
@ -901,6 +960,28 @@ bool Aggregator::executeOnBlock(Columns columns, UInt64 num_rows, AggregatedData
|
||||
result.without_key = place;
|
||||
}
|
||||
|
||||
// std::cerr << "Key columns after " << key_columns.size() << std::endl;
|
||||
// for (const auto & column : key_columns)
|
||||
// {
|
||||
// if (column)
|
||||
// std::cerr << column->dumpStructure() << "\n";
|
||||
// }
|
||||
|
||||
// std::cerr << "Aggregate columns after " << aggregate_columns.size() << std::endl;
|
||||
// for (size_t i = 0; i < aggregate_columns.size(); ++i)
|
||||
// {
|
||||
// const auto & aggregate_function_columns = aggregate_columns[i];
|
||||
|
||||
// for (const auto & aggregate_function_column : aggregate_function_columns)
|
||||
// {
|
||||
// if (aggregate_function_column)
|
||||
// {
|
||||
// std::cerr << "Aggregate function column " << static_cast<const void *>(aggregate_function_column) << std::endl;
|
||||
// std::cerr << aggregate_function_column->dumpStructure() << "\n";
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
/// We select one of the aggregation methods and call it.
|
||||
|
||||
/// For the case when there are no keys (all aggregate into one row).
|
||||
@ -1291,6 +1372,7 @@ void NO_INLINE Aggregator::convertToBlockImplFinal(
|
||||
auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns, key_sizes);
|
||||
const auto & key_sizes_ref = shuffled_key_sizes ? *shuffled_key_sizes : key_sizes;
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
if (compiled_functions)
|
||||
{
|
||||
std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[data.size()]);
|
||||
@ -1316,6 +1398,7 @@ void NO_INLINE Aggregator::convertToBlockImplFinal(
|
||||
insert_aggregate_states_function(data.size(), columns_data.data(), places.get());
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
data.forEachValue([&](const auto & key, auto & mapped)
|
||||
{
|
||||
@ -1751,6 +1834,7 @@ void NO_INLINE Aggregator::mergeDataImpl(
|
||||
if constexpr (Method::low_cardinality_optimization)
|
||||
mergeDataNullKey<Method, Table>(table_dst, table_src, arena);
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
if (compiled_functions)
|
||||
{
|
||||
auto merge_aggregate_states_function_typed = compiled_functions->merge_aggregate_states_function;
|
||||
@ -1770,6 +1854,7 @@ void NO_INLINE Aggregator::mergeDataImpl(
|
||||
});
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
table_src.mergeToViaEmplace(table_dst, [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted)
|
||||
{
|
||||
|
@ -1083,7 +1083,9 @@ private:
|
||||
/// For external aggregation.
|
||||
TemporaryFiles temporary_files;
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
std::optional<CompiledAggregateFunctions> compiled_functions;
|
||||
#endif
|
||||
|
||||
/** Try to compile aggregate functions.
|
||||
*/
|
||||
|
@ -299,8 +299,6 @@ static void compileCreateAggregateStatesFunctions(llvm::Module & module, const s
|
||||
aggregate_function->compileCreate(b, aggregation_place_with_offset);
|
||||
}
|
||||
|
||||
module.print(llvm::errs(), nullptr);
|
||||
|
||||
b.CreateRetVoid();
|
||||
}
|
||||
|
||||
@ -328,12 +326,29 @@ static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const
|
||||
auto * entry = llvm::BasicBlock::Create(b.getContext(), "entry", aggregate_loop_func_definition);
|
||||
b.SetInsertPoint(entry);
|
||||
|
||||
std::vector<ColumnDataPlaceholder> columns(functions.size());
|
||||
std::vector<ColumnDataPlaceholder> columns;
|
||||
size_t previous_columns_size = 0;
|
||||
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
{
|
||||
auto argument_type = functions[i].function->getArgumentTypes()[0];
|
||||
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, i));
|
||||
columns[i].data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
|
||||
auto argument_types = functions[i].function->getArgumentTypes();
|
||||
|
||||
ColumnDataPlaceholder data_placeholder;
|
||||
|
||||
std::cerr << "Function " << functions[i].function->getName() << std::endl;
|
||||
|
||||
size_t function_arguments_size = argument_types.size();
|
||||
|
||||
for (size_t column_argument_index = 0; column_argument_index < function_arguments_size; ++column_argument_index)
|
||||
{
|
||||
const auto & argument_type = argument_types[previous_columns_size + column_argument_index];
|
||||
auto * data = b.CreateLoad(column_data_type, b.CreateConstInBoundsGEP1_32(column_data_type, columns_arg, column_argument_index));
|
||||
std::cerr << "Argument type " << argument_type->getName() << std::endl;
|
||||
data_placeholder.data_init = b.CreatePointerCast(b.CreateExtractValue(data, {0}), toNativeType(b, removeNullable(argument_type))->getPointerTo());
|
||||
columns.emplace_back(data_placeholder);
|
||||
}
|
||||
|
||||
previous_columns_size += function_arguments_size;
|
||||
}
|
||||
|
||||
/// Initialize loop
|
||||
@ -356,16 +371,28 @@ static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const
|
||||
|
||||
auto * aggregation_place = b.CreateCall(get_place_func_declaration, get_place_function_arg, { get_place_function_context_arg, counter_phi });
|
||||
|
||||
for (size_t i = 0; i < functions.size(); ++i)
|
||||
previous_columns_size = 0;
|
||||
for (const auto & function : functions)
|
||||
{
|
||||
size_t aggregate_function_offset = functions[i].aggregate_data_offset;
|
||||
const auto * aggregate_function_ptr = functions[i].function;
|
||||
size_t aggregate_function_offset = function.aggregate_data_offset;
|
||||
const auto * aggregate_function_ptr = function.function;
|
||||
|
||||
auto arguments_types = function.function->getArgumentTypes();
|
||||
std::vector<llvm::Value *> arguments_values;
|
||||
|
||||
size_t function_arguments_size = arguments_types.size();
|
||||
arguments_values.resize(function_arguments_size);
|
||||
|
||||
for (size_t column_argument_index = 0; column_argument_index < function_arguments_size; ++column_argument_index)
|
||||
{
|
||||
auto * column_argument_data = columns[previous_columns_size + column_argument_index].data;
|
||||
arguments_values[column_argument_index] = b.CreateLoad(toNativeType(b, arguments_types[column_argument_index]), column_argument_data);
|
||||
}
|
||||
|
||||
auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_32(nullptr, aggregation_place, aggregate_function_offset);
|
||||
aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, arguments_types, arguments_values);
|
||||
|
||||
auto column_type = functions[i].function->getArgumentTypes()[0];
|
||||
auto * column_data = b.CreateLoad(toNativeType(b, column_type), columns[i].data);
|
||||
aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, column_type, column_data);
|
||||
previous_columns_size += function_arguments_size;
|
||||
}
|
||||
|
||||
/// End of loop
|
||||
@ -374,12 +401,13 @@ static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, const
|
||||
for (auto & col : columns)
|
||||
{
|
||||
col.data->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.data, 1), cur_block);
|
||||
|
||||
if (col.null)
|
||||
col.null->addIncoming(b.CreateConstInBoundsGEP1_32(nullptr, col.null, 1), cur_block);
|
||||
}
|
||||
|
||||
auto * value = b.CreateAdd(counter_phi, llvm::ConstantInt::get(size_type, 1));
|
||||
counter_phi->addIncoming(value, loop);
|
||||
counter_phi->addIncoming(value, cur_block);
|
||||
|
||||
b.CreateCondBr(b.CreateICmpEQ(value, rows_count_arg), end, loop);
|
||||
|
||||
|
@ -1,5 +1,11 @@
|
||||
#include <iostream>
|
||||
|
||||
#if !defined(ARCADIA_BUILD)
|
||||
# include "config_core.h"
|
||||
#endif
|
||||
|
||||
#if USE_EMBEDDED_COMPILER
|
||||
|
||||
#include <llvm/IR/IRBuilder.h>
|
||||
|
||||
#include <Interpreters/JIT/CHJIT.h>
|
||||
@ -56,3 +62,14 @@ int main(int argc, char **argv)
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
/// Stub entry point compiled when USE_EMBEDDED_COMPILER is off:
/// it ignores its arguments and reports success.
int main(int argc, char **argv)
{
    /// Mark both parameters as intentionally unused.
    static_cast<void>(argc);
    static_cast<void>(argv);

    return 0;
}
|
||||
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user