From 129b97c72783e64b0bc7485b72ec1a82f8db0045 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Sat, 3 Jun 2023 21:31:08 +0300 Subject: [PATCH 1/5] JIT infrastructure refactoring --- src/AggregateFunctions/AggregateFunctionAvg.h | 12 +- .../AggregateFunctionAvgWeighted.h | 11 +- .../AggregateFunctionBitwise.h | 5 +- .../AggregateFunctionCount.h | 6 +- .../AggregateFunctionIf.cpp | 48 ++- src/AggregateFunctions/AggregateFunctionIf.h | 24 +- .../AggregateFunctionMinMaxAny.h | 4 +- .../AggregateFunctionNull.h | 39 +-- src/AggregateFunctions/AggregateFunctionSum.h | 7 +- src/AggregateFunctions/IAggregateFunction.h | 3 +- src/Core/ValueWithType.h | 26 ++ src/Core/ValuesWithType.h | 13 + src/DataTypes/IDataType.h | 5 - src/DataTypes/Native.cpp | 199 ++++++++++++ src/DataTypes/Native.h | 284 ++++-------------- src/Functions/FunctionBinaryArithmetic.h | 35 ++- src/Functions/FunctionIfBase.h | 24 +- src/Functions/FunctionUnaryArithmetic.h | 29 +- src/Functions/FunctionsComparison.h | 20 +- src/Functions/FunctionsLogical.h | 33 +- src/Functions/IFunction.cpp | 64 ++-- src/Functions/IFunction.h | 13 +- src/Functions/IFunctionAdaptors.h | 6 +- src/Interpreters/ExpressionJIT.cpp | 4 +- src/Interpreters/JIT/CHJIT.h | 4 +- src/Interpreters/JIT/CompileDAG.cpp | 21 +- src/Interpreters/JIT/CompileDAG.h | 2 +- src/Interpreters/JIT/compileFunction.cpp | 24 +- 28 files changed, 523 insertions(+), 442 deletions(-) create mode 100644 src/Core/ValueWithType.h create mode 100644 src/Core/ValuesWithType.h create mode 100644 src/DataTypes/Native.cpp diff --git a/src/AggregateFunctions/AggregateFunctionAvg.h b/src/AggregateFunctions/AggregateFunctionAvg.h index a86c7d042fc..37f20fca01c 100644 --- a/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/src/AggregateFunctions/AggregateFunctionAvg.h @@ -146,8 +146,8 @@ public: for (const auto & argument : this->argument_types) can_be_compiled &= canBeNativeType(*argument); - auto return_type = this->getResultType(); - can_be_compiled &= canBeNativeType(*return_type); + const auto & result_type = this->getResultType(); + can_be_compiled &= canBeNativeType(*result_type); return can_be_compiled; } @@ -198,8 +198,8 @@ public: auto * denominator_ptr = b.CreateConstGEP1_32(b.getInt8Ty(), aggregate_data_ptr, denominator_offset); auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr); - auto * double_numerator = nativeCast(b, numerator_value, b.getDoubleTy()); - auto * double_denominator = nativeCast(b, denominator_value, b.getDoubleTy()); + auto * double_numerator = nativeCast(b, numerator_value, this->getResultType()); + auto * double_denominator = nativeCast(b, denominator_value, this->getResultType()); return b.CreateFDiv(double_numerator, double_denominator); } @@ -308,7 +308,7 @@ public: #if USE_EMBEDDED_COMPILER - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { llvm::IRBuilder<> & b = static_cast &>(builder); @@ -316,7 +316,7 @@ public: auto * numerator_ptr = aggregate_data_ptr; auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr); - auto * value_cast_to_numerator = nativeCast(b, arguments_types[0], argument_values[0], numerator_type); + auto * value_cast_to_numerator = nativeCast(b, arguments[0], toNativeDataType()); auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_cast_to_numerator) : b.CreateFAdd(numerator_value, value_cast_to_numerator); b.CreateStore(numerator_result_value, numerator_ptr); diff --git a/src/AggregateFunctions/AggregateFunctionAvgWeighted.h b/src/AggregateFunctions/AggregateFunctionAvgWeighted.h index bc3e3a32a71..5a3869032ca 100644 --- a/src/AggregateFunctions/AggregateFunctionAvgWeighted.h +++ b/src/AggregateFunctions/AggregateFunctionAvgWeighted.h @@ -30,7 +30,7 @@ public: using Numerator = typename Base::Numerator; using Denominator = typename Base::Denominator; - using Fraction = typename Base::Fraction; + using Fraction = typename Base::Fraction; void NO_SANITIZE_UNDEFINED add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { @@ -55,7 +55,7 @@ public: return can_be_compiled; } - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { llvm::IRBuilder<> & b = static_cast &>(builder); @@ -63,8 +63,9 @@ public: auto * numerator_ptr = aggregate_data_ptr; auto * numerator_value = b.CreateLoad(numerator_type, numerator_ptr); - auto * argument = nativeCast(b, arguments_types[0], argument_values[0], numerator_type); - auto * weight = nativeCast(b, arguments_types[1], argument_values[1], numerator_type); + auto numerator_data_type = toNativeDataType(); + auto * argument = nativeCast(b, arguments[0], numerator_data_type); + auto * weight = nativeCast(b, arguments[1], numerator_data_type); llvm::Value * value_weight_multiplication = argument->getType()->isIntegerTy() ? b.CreateMul(argument, weight) : b.CreateFMul(argument, weight); auto * numerator_result_value = numerator_type->isIntegerTy() ? b.CreateAdd(numerator_value, value_weight_multiplication) : b.CreateFAdd(numerator_value, value_weight_multiplication); @@ -75,7 +76,7 @@ public: static constexpr size_t denominator_offset = offsetof(Fraction, denominator); auto * denominator_ptr = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, denominator_offset); - auto * weight_cast_to_denominator = nativeCast(b, arguments_types[1], argument_values[1], denominator_type); + auto * weight_cast_to_denominator = nativeCast(b, arguments[1], toNativeDataType()); auto * denominator_value = b.CreateLoad(denominator_type, denominator_ptr); auto * denominator_value_updated = denominator_type->isIntegerTy() ? b.CreateAdd(denominator_value, weight_cast_to_denominator) : b.CreateFAdd(denominator_value, weight_cast_to_denominator); diff --git a/src/AggregateFunctions/AggregateFunctionBitwise.h b/src/AggregateFunctions/AggregateFunctionBitwise.h index 6c94a72bf32..71479b309c7 100644 --- a/src/AggregateFunctions/AggregateFunctionBitwise.h +++ b/src/AggregateFunctions/AggregateFunctionBitwise.h @@ -148,7 +148,7 @@ public: Data::compileCreate(builder, value_ptr); } - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes &, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { llvm::IRBuilder<> & b = static_cast &>(builder); @@ -157,8 +157,7 @@ public: auto * value_ptr = aggregate_data_ptr; auto * value = b.CreateLoad(return_type, value_ptr); - const auto & argument_value = argument_values[0]; - auto * result_value = Data::compileUpdate(builder, value, argument_value); + auto * result_value = Data::compileUpdate(builder, value, arguments[0].value); b.CreateStore(result_value, value_ptr); } diff --git a/src/AggregateFunctions/AggregateFunctionCount.h b/src/AggregateFunctions/AggregateFunctionCount.h index 848a8a4b603..77d3bfeb448 100644 --- a/src/AggregateFunctions/AggregateFunctionCount.h +++ b/src/AggregateFunctions/AggregateFunctionCount.h @@ -165,7 +165,7 @@ public: b.CreateMemSet(aggregate_data_ptr, llvm::ConstantInt::get(b.getInt8Ty(), 0), sizeof(AggregateFunctionCountData), llvm::assumeAligned(this->alignOfData())); } - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes &, const std::vector &) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType &) const override { llvm::IRBuilder<> & b = static_cast &>(builder); @@ -309,13 +309,13 @@ public: b.CreateMemSet(aggregate_data_ptr, llvm::ConstantInt::get(b.getInt8Ty(), 0), sizeof(AggregateFunctionCountData), llvm::assumeAligned(this->alignOfData())); } - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes &, const std::vector & values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { llvm::IRBuilder<> & b = static_cast &>(builder); auto * return_type = toNativeType(b, this->getResultType()); - auto * is_null_value = b.CreateExtractValue(values[0], {1}); + auto * is_null_value = b.CreateExtractValue(arguments[0].value, {1}); auto * increment_value = b.CreateSelect(is_null_value, llvm::ConstantInt::get(return_type, 0), llvm::ConstantInt::get(return_type, 1)); auto * count_value_ptr = aggregate_data_ptr; diff --git a/src/AggregateFunctions/AggregateFunctionIf.cpp b/src/AggregateFunctions/AggregateFunctionIf.cpp index 20bdb32796a..87fa8239507 100644 --- a/src/AggregateFunctions/AggregateFunctionIf.cpp +++ b/src/AggregateFunctions/AggregateFunctionIf.cpp @@ -188,18 +188,18 @@ public: return canBeNativeType(*this->argument_types.back()) && this->nested_function->isCompilable(); } - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { llvm::IRBuilder<> & b = static_cast &>(builder); - const auto & nullable_type = arguments_types[0]; - const auto & nullable_value = argument_values[0]; + const auto & nullable_type = arguments[0].type; + const auto & nullable_value = arguments[0].value; auto * wrapped_value = b.CreateExtractValue(nullable_value, {0}); auto * is_null_value = b.CreateExtractValue(nullable_value, {1}); - const auto & predicate_type = arguments_types[argument_values.size() - 1]; - auto * predicate_value = argument_values[argument_values.size() - 1]; + const auto & predicate_type = arguments.back().type; + auto * predicate_value = arguments.back().value; auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value); auto * head = b.GetInsertBlock(); @@ -219,7 +219,7 @@ public: b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr); auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, this->prefix_size); - this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { removeNullable(nullable_type) }, { wrapped_value }); + this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { ValueWithType(wrapped_value, removeNullable(nullable_type)) }); b.CreateBr(join_block); b.SetInsertPoint(join_block); @@ -370,38 +370,31 @@ public: return canBeNativeType(*this->argument_types.back()) && this->nested_function->isCompilable(); } - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { - /// TODO: Check - llvm::IRBuilder<> & b = static_cast &>(builder); - size_t arguments_size = arguments_types.size(); + size_t arguments_size = arguments.size(); + + ValuesWithType wrapped_arguments; + wrapped_arguments.reserve(arguments_size); - DataTypes non_nullable_types; - std::vector wrapped_values; std::vector is_null_values; - non_nullable_types.resize(arguments_size); - wrapped_values.resize(arguments_size); - is_null_values.resize(arguments_size); - for (size_t i = 0; i < arguments_size; ++i) { - const auto & argument_value = argument_values[i]; + const auto & argument_value = arguments[i].value; + const auto & argument_type = arguments[i].type; if (is_nullable[i]) { auto * wrapped_value = b.CreateExtractValue(argument_value, {0}); - is_null_values[i] = b.CreateExtractValue(argument_value, {1}); - - wrapped_values[i] = wrapped_value; - non_nullable_types[i] = removeNullable(arguments_types[i]); + is_null_values.emplace_back(b.CreateExtractValue(argument_value, {1})); + wrapped_arguments.emplace_back(wrapped_value, removeNullable(argument_type)); } else { - wrapped_values[i] = argument_value; - non_nullable_types[i] = arguments_types[i]; + wrapped_arguments.emplace_back(argument_value, argument_type); } } @@ -415,9 +408,6 @@ public: for (auto * is_null_value : is_null_values) { - if (!is_null_value) - continue; - auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr); b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr); } @@ -426,8 +416,8 @@ public: b.SetInsertPoint(join_block_after_null_checks); - const auto & predicate_type = arguments_types[argument_values.size() - 1]; - auto * predicate_value = argument_values[argument_values.size() - 1]; + const auto & predicate_type = arguments.back().type; + auto * predicate_value = arguments.back().value; auto * is_predicate_true = nativeBoolCast(b, predicate_type, predicate_value); auto * if_true = llvm::BasicBlock::Create(head->getContext(), "if_true", head->getParent()); @@ -444,7 +434,7 @@ public: b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr); auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, this->prefix_size); - this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, non_nullable_types, wrapped_values); + this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, wrapped_arguments); b.CreateBr(join_block); b.SetInsertPoint(join_block); diff --git a/src/AggregateFunctions/AggregateFunctionIf.h b/src/AggregateFunctions/AggregateFunctionIf.h index cd7d7e27a25..afab861e202 100644 --- a/src/AggregateFunctions/AggregateFunctionIf.h +++ b/src/AggregateFunctions/AggregateFunctionIf.h @@ -223,12 +223,12 @@ public: nested_func->compileCreate(builder, aggregate_data_ptr); } - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { llvm::IRBuilder<> & b = static_cast &>(builder); - const auto & predicate_type = arguments_types[argument_values.size() - 1]; - auto * predicate_value = argument_values[argument_values.size() - 1]; + const auto & predicate_type = arguments.back().type; + auto * predicate_value = arguments.back().value; auto * head = b.GetInsertBlock(); @@ -242,21 +242,9 @@ public: b.SetInsertPoint(if_true); - size_t arguments_size_without_predicate = arguments_types.size() - 1; - - DataTypes argument_types_without_predicate; - std::vector argument_values_without_predicate; - - argument_types_without_predicate.resize(arguments_size_without_predicate); - argument_values_without_predicate.resize(arguments_size_without_predicate); - - for (size_t i = 0; i < arguments_size_without_predicate; ++i) - { - argument_types_without_predicate[i] = arguments_types[i]; - argument_values_without_predicate[i] = argument_values[i]; - } - - nested_func->compileAdd(builder, aggregate_data_ptr, argument_types_without_predicate, argument_values_without_predicate); + ValuesWithType arguments_without_predicate = arguments; + arguments_without_predicate.pop_back(); + nested_func->compileAdd(builder, aggregate_data_ptr, arguments_without_predicate); b.CreateBr(join_block); diff --git a/src/AggregateFunctions/AggregateFunctionMinMaxAny.h b/src/AggregateFunctions/AggregateFunctionMinMaxAny.h index 94c0d60be81..5312df32459 100644 --- a/src/AggregateFunctions/AggregateFunctionMinMaxAny.h +++ b/src/AggregateFunctions/AggregateFunctionMinMaxAny.h @@ -1459,11 +1459,11 @@ public: b.CreateMemSet(aggregate_data_ptr, llvm::ConstantInt::get(b.getInt8Ty(), 0), this->sizeOfData(), llvm::assumeAligned(this->alignOfData())); } - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes &, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { if constexpr (Data::is_compilable) { - Data::compileChangeIfBetter(builder, aggregate_data_ptr, argument_values[0]); + Data::compileChangeIfBetter(builder, aggregate_data_ptr, arguments[0].value); } else { diff --git a/src/AggregateFunctions/AggregateFunctionNull.h b/src/AggregateFunctions/AggregateFunctionNull.h index de7b190c949..6b6580bf4c4 100644 --- a/src/AggregateFunctions/AggregateFunctionNull.h +++ b/src/AggregateFunctions/AggregateFunctionNull.h @@ -378,12 +378,12 @@ public: #if USE_EMBEDDED_COMPILER - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { llvm::IRBuilder<> & b = static_cast &>(builder); - const auto & nullable_type = arguments_types[0]; - const auto & nullable_value = argument_values[0]; + const auto & nullable_type = arguments[0].type; + const auto & nullable_value = arguments[0].value; auto * wrapped_value = b.CreateExtractValue(nullable_value, {0}); auto * is_null_value = b.CreateExtractValue(nullable_value, {1}); @@ -405,7 +405,7 @@ public: b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr); auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, this->prefix_size); - this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { removeNullable(nullable_type) }, { wrapped_value }); + this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, { ValueWithType(wrapped_value, removeNullable(nullable_type)) }); b.CreateBr(join_block); b.SetInsertPoint(join_block); @@ -568,36 +568,32 @@ public: #if USE_EMBEDDED_COMPILER - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { llvm::IRBuilder<> & b = static_cast &>(builder); - size_t arguments_size = arguments_types.size(); + size_t arguments_size = arguments.size(); - DataTypes non_nullable_types; - std::vector wrapped_values; - std::vector is_null_values; + ValuesWithType wrapped_arguments; + wrapped_arguments.reserve(arguments_size); - non_nullable_types.resize(arguments_size); - wrapped_values.resize(arguments_size); - is_null_values.resize(arguments_size); + std::vector is_null_values; + is_null_values.reserve(arguments_size); for (size_t i = 0; i < arguments_size; ++i) { - const auto & argument_value = argument_values[i]; + const auto & argument_value = arguments[i].value; + const auto & argument_type = arguments[i].type; if (is_nullable[i]) { auto * wrapped_value = b.CreateExtractValue(argument_value, {0}); - is_null_values[i] = b.CreateExtractValue(argument_value, {1}); - - wrapped_values[i] = wrapped_value; - non_nullable_types[i] = removeNullable(arguments_types[i]); + is_null_values.emplace_back(b.CreateExtractValue(argument_value, {1})); + wrapped_arguments.emplace_back(wrapped_value, removeNullable(argument_type)); } else { - wrapped_values[i] = argument_value; - non_nullable_types[i] = arguments_types[i]; + wrapped_arguments.emplace_back(argument_value, argument_type); } } @@ -612,9 +608,6 @@ public: for (auto * is_null_value : is_null_values) { - if (!is_null_value) - continue; - auto * values_have_null = b.CreateLoad(b.getInt1Ty(), values_have_null_ptr); b.CreateStore(b.CreateOr(values_have_null, is_null_value), values_have_null_ptr); } @@ -630,7 +623,7 @@ public: b.CreateStore(llvm::ConstantInt::get(b.getInt8Ty(), 1), aggregate_data_ptr); auto * aggregate_data_ptr_with_prefix_size_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregate_data_ptr, this->prefix_size); - this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, arguments_types, wrapped_values); + this->nested_function->compileAdd(b, aggregate_data_ptr_with_prefix_size_offset, wrapped_arguments); b.CreateBr(join_block); b.SetInsertPoint(join_block); diff --git a/src/AggregateFunctions/AggregateFunctionSum.h b/src/AggregateFunctions/AggregateFunctionSum.h index f77d1dae36f..bb0804c14b3 100644 --- a/src/AggregateFunctions/AggregateFunctionSum.h +++ b/src/AggregateFunctions/AggregateFunctionSum.h @@ -588,7 +588,7 @@ public: b.CreateStore(llvm::Constant::getNullValue(return_type), aggregate_sum_ptr); } - void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const DataTypes & arguments_types, const std::vector & argument_values) const override + void compileAdd(llvm::IRBuilderBase & builder, llvm::Value * aggregate_data_ptr, const ValuesWithType & arguments) const override { llvm::IRBuilder<> & b = static_cast &>(builder); @@ -597,10 +597,7 @@ public: auto * sum_value_ptr = aggregate_data_ptr; auto * sum_value = b.CreateLoad(return_type, sum_value_ptr); - const auto & argument_type = arguments_types[0]; - const auto & argument_value = argument_values[0]; - - auto * value_cast_to_result = nativeCast(b, argument_type, argument_value, return_type); + auto * value_cast_to_result = nativeCast(b, arguments[0], this->getResultType()); auto * sum_result_value = sum_value->getType()->isIntegerTy() ? b.CreateAdd(sum_value, value_cast_to_result) : b.CreateFAdd(sum_value, value_cast_to_result); b.CreateStore(sum_result_value, sum_value_ptr); diff --git a/src/AggregateFunctions/IAggregateFunction.h b/src/AggregateFunctions/IAggregateFunction.h index ddc0535d0e4..df08b6f2109 100644 --- a/src/AggregateFunctions/IAggregateFunction.h +++ b/src/AggregateFunctions/IAggregateFunction.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -389,7 +390,7 @@ public: } /// compileAdd should generate code for updating aggregate function state stored in aggregate_data_ptr - virtual void compileAdd(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/, const DataTypes & /*arguments_types*/, const std::vector & /*arguments_values*/) const + virtual void compileAdd(llvm::IRBuilderBase & /*builder*/, llvm::Value * /*aggregate_data_ptr*/, const ValuesWithType & /*arguments*/) const { throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName()); } diff --git a/src/Core/ValueWithType.h b/src/Core/ValueWithType.h new file mode 100644 index 00000000000..b5f61a1c5f7 --- /dev/null +++ b/src/Core/ValueWithType.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +namespace llvm +{ + class Value; +} + +namespace DB +{ + +/// LLVM value with its data type +struct ValueWithType +{ + llvm::Value * value = nullptr; + DataTypePtr type; + + ValueWithType() = default; + ValueWithType(llvm::Value * value_, DataTypePtr type_) + : value(value_) + , type(std::move(type_)) + {} +}; + +} diff --git a/src/Core/ValuesWithType.h b/src/Core/ValuesWithType.h new file mode 100644 index 00000000000..92060419197 --- /dev/null +++ b/src/Core/ValuesWithType.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +#include + + +namespace DB +{ + +using ValuesWithType = std::vector; + +} diff --git a/src/DataTypes/IDataType.h b/src/DataTypes/IDataType.h index 7cc18fea00c..7a705e8fd19 100644 --- a/src/DataTypes/IDataType.h +++ b/src/DataTypes/IDataType.h @@ -532,11 +532,6 @@ inline bool isNotDecimalButComparableToDecimal(const DataTypePtr & data_type) return which.isInt() || which.isUInt() || which.isFloat(); } -inline bool isCompilableType(const DataTypePtr & data_type) -{ - return data_type->isValueRepresentedByNumber() && !isDecimal(data_type); -} - inline bool isBool(const DataTypePtr & data_type) { return data_type->getName() == "Bool"; diff --git a/src/DataTypes/Native.cpp b/src/DataTypes/Native.cpp new file mode 100644 index 00000000000..acbd70ba04f --- /dev/null +++ b/src/DataTypes/Native.cpp @@ -0,0 +1,199 @@ +#include + +#if USE_EMBEDDED_COMPILER +# include +# include +# include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NOT_IMPLEMENTED; +} + +bool typeIsSigned(const IDataType & type) +{ + WhichDataType data_type(type); + return data_type.isNativeInt() || data_type.isFloat() || data_type.isEnum() || data_type.isDate32(); +} + +llvm::Type * toNullableType(llvm::IRBuilderBase & builder, llvm::Type * type) +{ + auto * is_null_type = builder.getInt1Ty(); + return llvm::StructType::get(type, is_null_type); +} + +bool canBeNativeType(const IDataType & type) +{ + WhichDataType data_type(type); + + if (data_type.isNullable()) + { + const auto & data_type_nullable = static_cast(type); + return canBeNativeType(*data_type_nullable.getNestedType()); + } + + return data_type.isNativeInt() || data_type.isNativeUInt() || data_type.isFloat() || data_type.isDate() + || data_type.isDate32() || data_type.isDateTime() || data_type.isEnum(); +} + +bool canBeNativeType(const DataTypePtr & type) +{ + return canBeNativeType(*type); +} + +llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const IDataType & type) +{ + WhichDataType data_type(type); + + if (data_type.isNullable()) + { + const auto & data_type_nullable = static_cast(type); + auto * nested_type = toNativeType(builder, *data_type_nullable.getNestedType()); + return toNullableType(builder, nested_type); + } + + /// LLVM doesn't have unsigned types, it has unsigned instructions. + if (data_type.isInt8() || data_type.isUInt8()) + return builder.getInt8Ty(); + else if (data_type.isInt16() || data_type.isUInt16() || data_type.isDate()) + return builder.getInt16Ty(); + else if (data_type.isInt32() || data_type.isUInt32() || data_type.isDate32() || data_type.isDateTime()) + return builder.getInt32Ty(); + else if (data_type.isInt64() || data_type.isUInt64()) + return builder.getInt64Ty(); + else if (data_type.isFloat32()) + return builder.getFloatTy(); + else if (data_type.isFloat64()) + return builder.getDoubleTy(); + else if (data_type.isEnum8()) + return builder.getInt8Ty(); + else if (data_type.isEnum16()) + return builder.getInt16Ty(); + + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid cast to native type"); +} + +llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePtr & type) +{ + return toNativeType(builder, *type); +} + +llvm::Value * nativeBoolCast(llvm::IRBuilderBase & b, const DataTypePtr & from_type, llvm::Value * value) +{ + if (from_type->isNullable()) + { + auto * inner = nativeBoolCast(b, removeNullable(from_type), b.CreateExtractValue(value, {0})); + return b.CreateAnd(b.CreateNot(b.CreateExtractValue(value, {1})), inner); + } + + auto * zero = llvm::Constant::getNullValue(value->getType()); + + if (value->getType()->isIntegerTy()) + return b.CreateICmpNE(value, zero); + else if (value->getType()->isFloatingPointTy()) + return b.CreateFCmpUNE(value, zero); + + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot cast non-number {} to bool", from_type->getName()); +} + +llvm::Value * nativeBoolCast(llvm::IRBuilderBase & b, const ValueWithType & value_with_type) +{ + return nativeBoolCast(b, value_with_type.type, value_with_type.value); +} + +llvm::Value * nativeCast(llvm::IRBuilderBase & b, const DataTypePtr & from_type, llvm::Value * value, const DataTypePtr & to_type) +{ + if (from_type->equals(*to_type)) + { + return value; + } + else if (from_type->isNullable() && to_type->isNullable()) + { + auto * inner = nativeCast(b, removeNullable(from_type), b.CreateExtractValue(value, {0}), to_type); + return b.CreateInsertValue(inner, b.CreateExtractValue(value, {1}), {1}); + } + else if (from_type->isNullable()) + { + return nativeCast(b, removeNullable(from_type), b.CreateExtractValue(value, {0}), to_type); + } + else if (to_type->isNullable()) + { + auto * from_native_type = toNativeType(b, from_type); + auto * inner = nativeCast(b, from_type, value, removeNullable(to_type)); + return b.CreateInsertValue(llvm::Constant::getNullValue(from_native_type), inner, {0}); + } + else + { + auto * from_native_type = toNativeType(b, from_type); + auto * to_native_type = toNativeType(b, to_type); + + if (from_native_type == to_native_type) + return value; + else if (from_native_type->isIntegerTy() && to_native_type->isFloatingPointTy()) + return typeIsSigned(*from_type) ? b.CreateSIToFP(value, to_native_type) : b.CreateUIToFP(value, to_native_type); + else if (from_native_type->isFloatingPointTy() && to_native_type->isIntegerTy()) + return typeIsSigned(*to_type) ? b.CreateFPToSI(value, to_native_type) : b.CreateFPToUI(value, to_native_type); + else if (from_native_type->isIntegerTy() && from_native_type->isIntegerTy()) + return b.CreateIntCast(value, to_native_type, typeIsSigned(*from_type)); + else if (to_native_type->isFloatingPointTy() && to_native_type->isFloatingPointTy()) + return b.CreateFPCast(value, to_native_type); + } + + throw Exception(ErrorCodes::LOGICAL_ERROR, + "Invalid cast to native value from type {} to type {}", + from_type->getName(), + to_type->getName()); +} + +llvm::Value * nativeCast(llvm::IRBuilderBase & b, const ValueWithType & value, const DataTypePtr & to_type) +{ + return nativeCast(b, value.type, value.value, to_type); +} + +llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index) +{ + if (const auto * constant = typeid_cast(&column)) + return getColumnNativeValue(builder, column_type, constant->getDataColumn(), 0); + + auto * type = toNativeType(builder, column_type); + + WhichDataType column_data_type(column_type); + if (column_data_type.isNullable()) + { + const auto & nullable_data_type = assert_cast(*column_type); + const auto & nullable_column = assert_cast(column); + + auto * value = getColumnNativeValue(builder, nullable_data_type.getNestedType(), nullable_column.getNestedColumn(), index); + auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable_column.isNullAt(index)); + + return llvm::ConstantStruct::get(static_cast(type), value, is_null); + } + else if (column_data_type.isFloat32()) + { + return llvm::ConstantFP::get(type, assert_cast &>(column).getElement(index)); + } + else if (column_data_type.isFloat64()) + { + return llvm::ConstantFP::get(type, assert_cast &>(column).getElement(index)); + } + else if (column_data_type.isNativeUInt() || column_data_type.isDate() || column_data_type.isDateTime()) + { + return llvm::ConstantInt::get(type, column.getUInt(index)); + } + else if (column_data_type.isNativeInt() || column_data_type.isEnum() || column_data_type.isDate32()) + { + return llvm::ConstantInt::get(type, column.getInt(index)); + } + + throw Exception(ErrorCodes::LOGICAL_ERROR, + "Cannot get native value for column with type {}", + column_type->getName()); +} + +} + +#endif diff --git a/src/DataTypes/Native.h b/src/DataTypes/Native.h index a3c8486fa60..7fee452b1f0 100644 --- a/src/DataTypes/Native.h +++ b/src/DataTypes/Native.h @@ -4,65 +4,53 @@ #if USE_EMBEDDED_COMPILER # include - +# include # include -# include -# include -# include # include namespace DB { + namespace ErrorCodes { - extern const int NOT_IMPLEMENTED; + extern const int LOGICAL_ERROR; } -static inline bool typeIsSigned(const IDataType & type) +/// Returns true if type is signed, false otherwise +bool typeIsSigned(const IDataType & type); + +/// Cast LLVM type to nullable LLVM type +llvm::Type * toNullableType(llvm::IRBuilderBase & builder, llvm::Type * type); + +/// Returns true if type can be native LLVM type, false otherwise +bool canBeNativeType(const IDataType & type); + +/// Returns true if type can be native LLVM type, false otherwise +bool canBeNativeType(const DataTypePtr & type); + +template +static inline bool canBeNativeType() { - WhichDataType data_type(type); - return data_type.isNativeInt() || data_type.isFloat() || data_type.isEnum(); + if constexpr (std::is_same_v || std::is_same_v) + return true; + else if constexpr (std::is_same_v || std::is_same_v) + return true; + else if constexpr (std::is_same_v || std::is_same_v) + return true; + else if constexpr (std::is_same_v || std::is_same_v) + return true; + else if constexpr (std::is_same_v || std::is_same_v) + return true; + + return false; } -static inline llvm::Type * toNullableType(llvm::IRBuilderBase & builder, llvm::Type * type) -{ - auto * is_null_type = builder.getInt1Ty(); - return llvm::StructType::get(type, is_null_type); -} +/// Cast type to native LLVM type +llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const IDataType & type); -static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const IDataType & type) -{ - WhichDataType data_type(type); - - if (data_type.isNullable()) - { - const auto & data_type_nullable = static_cast(type); - auto * wrapped = toNativeType(builder, *data_type_nullable.getNestedType()); - auto * is_null_type = builder.getInt1Ty(); - return wrapped ? llvm::StructType::get(wrapped, is_null_type) : nullptr; - } - - /// LLVM doesn't have unsigned types, it has unsigned instructions. - if (data_type.isInt8() || data_type.isUInt8()) - return builder.getInt8Ty(); - else if (data_type.isInt16() || data_type.isUInt16() || data_type.isDate()) - return builder.getInt16Ty(); - else if (data_type.isInt32() || data_type.isUInt32() || data_type.isDate32() || data_type.isDateTime()) - return builder.getInt32Ty(); - else if (data_type.isInt64() || data_type.isUInt64()) - return builder.getInt64Ty(); - else if (data_type.isFloat32()) - return builder.getFloatTy(); - else if (data_type.isFloat64()) - return builder.getDoubleTy(); - else if (data_type.isEnum8()) - return builder.getInt8Ty(); - else if (data_type.isEnum16()) - return builder.getInt16Ty(); - - return nullptr; -} +/// Cast type to native LLVM type +llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePtr & type); template static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder) @@ -80,203 +68,43 @@ static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder) else if constexpr (std::is_same_v) return builder.getDoubleTy(); - return nullptr; + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid cast to native type"); } -template -static inline bool canBeNativeType() +template +static inline DataTypePtr toNativeDataType() { - if constexpr (std::is_same_v || std::is_same_v) - return true; - else if constexpr (std::is_same_v || std::is_same_v) - return true; - else if constexpr (std::is_same_v || std::is_same_v) - return true; - else if constexpr (std::is_same_v || std::is_same_v) - return true; - else if constexpr (std::is_same_v) - return true; - else if constexpr (std::is_same_v) - return true; + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + return std::make_shared>(); - return false; + throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid cast to native data type"); } -static inline bool canBeNativeType(const IDataType & type) -{ - WhichDataType data_type(type); +/// Cast LLVM value with type to bool +llvm::Value * nativeBoolCast(llvm::IRBuilderBase & b, const DataTypePtr & from_type, llvm::Value * value); - if (data_type.isNullable()) - { - const auto & data_type_nullable = static_cast(type); - return canBeNativeType(*data_type_nullable.getNestedType()); - } +/// Cast LLVM value with type to bool +llvm::Value * nativeBoolCast(llvm::IRBuilderBase & b, const ValueWithType & value_with_type); - return data_type.isNativeInt() || data_type.isNativeUInt() || data_type.isFloat() || data_type.isDate() - || data_type.isDate32() || data_type.isDateTime() || data_type.isEnum(); -} +/// Cast LLVM value with type to specified type +llvm::Value * nativeCast(llvm::IRBuilderBase & b, const DataTypePtr & from_type, llvm::Value * value, const DataTypePtr & to_type); -static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const DataTypePtr & type) -{ - return toNativeType(builder, *type); -} - -static inline llvm::Value * nativeBoolCast(llvm::IRBuilder<> & b, const DataTypePtr & from_type, llvm::Value * value) -{ - if (from_type->isNullable()) - { - auto * inner = nativeBoolCast(b, removeNullable(from_type), b.CreateExtractValue(value, {0})); - return b.CreateAnd(b.CreateNot(b.CreateExtractValue(value, {1})), inner); - } - auto * zero = llvm::Constant::getNullValue(value->getType()); - - if (value->getType()->isIntegerTy()) - return b.CreateICmpNE(value, zero); - if (value->getType()->isFloatingPointTy()) - return b.CreateFCmpUNE(value, zero); - - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot cast non-number {} to bool", from_type->getName()); -} - -static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, llvm::Type * to_type) -{ - auto * from_type = value->getType(); - - if (from_type == to_type) - return value; - else if (from_type->isIntegerTy() && to_type->isFloatingPointTy()) - return typeIsSigned(*from) ? b.CreateSIToFP(value, to_type) : b.CreateUIToFP(value, to_type); - else if (from_type->isFloatingPointTy() && to_type->isIntegerTy()) - return typeIsSigned(*from) ? b.CreateFPToSI(value, to_type) : b.CreateFPToUI(value, to_type); - else if (from_type->isIntegerTy() && to_type->isIntegerTy()) - return b.CreateIntCast(value, to_type, typeIsSigned(*from)); - else if (from_type->isFloatingPointTy() && to_type->isFloatingPointTy()) - return b.CreateFPCast(value, to_type); - - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot cast {} to requested type", from->getName()); -} +/// Cast LLVM value with type to specified type +llvm::Value * nativeCast(llvm::IRBuilderBase & b, const ValueWithType & value, const DataTypePtr & to_type); template -static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, llvm::Value * value, llvm::Type * to_type) +static inline llvm::Value * nativeCast(llvm::IRBuilderBase & b, llvm::Value * value, const DataTypePtr & to) { - auto * from_type = value->getType(); - - static constexpr bool from_type_is_signed = std::numeric_limits::is_signed; - - if (from_type == to_type) - return value; - else if (from_type->isIntegerTy() && to_type->isFloatingPointTy()) - return from_type_is_signed ? b.CreateSIToFP(value, to_type) : b.CreateUIToFP(value, to_type); - else if (from_type->isFloatingPointTy() && to_type->isIntegerTy()) - return from_type_is_signed ? b.CreateFPToSI(value, to_type) : b.CreateFPToUI(value, to_type); - else if (from_type->isIntegerTy() && to_type->isIntegerTy()) - return b.CreateIntCast(value, to_type, from_type_is_signed); - else if (from_type->isFloatingPointTy() && to_type->isFloatingPointTy()) - return b.CreateFPCast(value, to_type); - - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot cast {} to requested type", TypeName); + auto native_data_type = toNativeDataType(); + return nativeCast(b, native_data_type, value, to); } -static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, const DataTypePtr & to) -{ - auto * n_to = toNativeType(b, to); - - if (value->getType() == n_to) - { - return value; - } - else if (from->isNullable() && to->isNullable()) - { - auto * inner = nativeCast(b, removeNullable(from), b.CreateExtractValue(value, {0}), to); - return b.CreateInsertValue(inner, b.CreateExtractValue(value, {1}), {1}); - } - else if (from->isNullable()) - { - return nativeCast(b, removeNullable(from), b.CreateExtractValue(value, {0}), to); - } - else if (to->isNullable()) - { - auto * inner = nativeCast(b, from, value, removeNullable(to)); - return b.CreateInsertValue(llvm::Constant::getNullValue(n_to), inner, {0}); - } - - return nativeCast(b, from, value, n_to); -} - -static inline std::pair nativeCastToCommon(llvm::IRBuilder<> & b, const DataTypePtr & lhs_type, llvm::Value * lhs, const DataTypePtr & rhs_type, llvm::Value * rhs) /// NOLINT -{ - llvm::Type * common; - - bool lhs_is_signed = typeIsSigned(*lhs_type); - bool rhs_is_signed = typeIsSigned(*rhs_type); - - if (lhs->getType()->isIntegerTy() && rhs->getType()->isIntegerTy()) - { - /// if one integer has a sign bit, make sure the other does as well. llvm generates optimal code - /// (e.g. uses overflow flag on x86) for (word size + 1)-bit integer operations. - - size_t lhs_bit_width = lhs->getType()->getIntegerBitWidth() + (!lhs_is_signed && rhs_is_signed); - size_t rhs_bit_width = rhs->getType()->getIntegerBitWidth() + (!rhs_is_signed && lhs_is_signed); - - size_t max_bit_width = std::max(lhs_bit_width, rhs_bit_width); - common = b.getIntNTy(static_cast(max_bit_width)); - } - else - { - /// TODO: Check - /// (double, float) or (double, int_N where N <= double's mantissa width) -> double - common = b.getDoubleTy(); - } - - auto * cast_lhs_to_common = nativeCast(b, lhs_type, lhs, common); - auto * cast_rhs_to_common = nativeCast(b, rhs_type, rhs, common); - - return std::make_pair(cast_lhs_to_common, cast_rhs_to_common); -} - -static inline llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index) -{ - if (const auto * constant = typeid_cast(&column)) - { - return getColumnNativeValue(builder, column_type, constant->getDataColumn(), 0); - } - - WhichDataType column_data_type(column_type); - - auto * type = toNativeType(builder, column_type); - - if (!type || column.size() <= index) - return nullptr; - - if (column_data_type.isNullable()) - { - const auto & nullable_data_type = assert_cast(*column_type); - const auto & nullable_column = assert_cast(column); - - auto * value = getColumnNativeValue(builder, nullable_data_type.getNestedType(), nullable_column.getNestedColumn(), index); - auto * is_null = llvm::ConstantInt::get(type->getContainedType(1), nullable_column.isNullAt(index)); - - return value ? llvm::ConstantStruct::get(static_cast(type), value, is_null) : nullptr; - } - else if (column_data_type.isFloat32()) - { - return llvm::ConstantFP::get(type, assert_cast &>(column).getElement(index)); - } - else if (column_data_type.isFloat64()) - { - return llvm::ConstantFP::get(type, assert_cast &>(column).getElement(index)); - } - else if (column_data_type.isNativeUInt() || column_data_type.isDate() || column_data_type.isDateTime()) - { - return llvm::ConstantInt::get(type, column.getUInt(index)); - } - else if (column_data_type.isNativeInt() || column_data_type.isEnum() || column_data_type.isDate32()) - { - return llvm::ConstantInt::get(type, column.getInt(index)); - } - - return nullptr; -} +/// Get column value for specified index as LLVM constant +llvm::Constant * getColumnNativeValue(llvm::IRBuilderBase & builder, const DataTypePtr & column_type, const IColumn & column, size_t index); } diff --git a/src/Functions/FunctionBinaryArithmetic.h b/src/Functions/FunctionBinaryArithmetic.h index b205822aab5..50d8abb9bcc 100644 --- a/src/Functions/FunctionBinaryArithmetic.h +++ b/src/Functions/FunctionBinaryArithmetic.h @@ -2046,51 +2046,62 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A } #if USE_EMBEDDED_COMPILER - bool isCompilableImpl(const DataTypes & arguments) const override + bool isCompilableImpl(const DataTypes & arguments, const DataTypePtr & result_type) const override { if (2 != arguments.size()) return false; + if (!canBeNativeType(*arguments[0]) || !canBeNativeType(*arguments[1]) || !canBeNativeType(*result_type)) + return false; + return castBothTypes(arguments[0].get(), arguments[1].get(), [&](const auto & left, const auto & right) { using LeftDataType = std::decay_t; using RightDataType = std::decay_t; - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) - return false; - else + if constexpr (!std::is_same_v && + !std::is_same_v && + !std::is_same_v && + !std::is_same_v) { using ResultDataType = typename BinaryOperationTraits::ResultDataType; using OpSpec = Op; - return !std::is_same_v && !IsDataTypeDecimal && OpSpec::compilable; + if constexpr (!std::is_same_v && !IsDataTypeDecimal && OpSpec::compilable) + return true; } + return false; }); } - llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, Values values) const override + llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const override { - assert(2 == types.size() && 2 == values.size()); + assert(2 == arguments.size()); llvm::Value * result = nullptr; - castBothTypes(types[0].get(), types[1].get(), [&](const auto & left, const auto & right) + castBothTypes(arguments[0].type.get(), arguments[1].type.get(), [&](const auto & left, const auto & right) { using LeftDataType = std::decay_t; using RightDataType = std::decay_t; - if constexpr (!std::is_same_v && !std::is_same_v && !std::is_same_v && !std::is_same_v) + if constexpr (!std::is_same_v && + !std::is_same_v && + !std::is_same_v && + !std::is_same_v) { using ResultDataType = typename BinaryOperationTraits::ResultDataType; using OpSpec = Op; if constexpr (!std::is_same_v && !IsDataTypeDecimal && OpSpec::compilable) { auto & b = static_cast &>(builder); - auto type = std::make_shared(); - auto * lval = nativeCast(b, types[0], values[0], type); - auto * rval = nativeCast(b, types[1], values[1], type); + auto * lval = nativeCast(b, arguments[0], result_type); + auto * rval = nativeCast(b, arguments[1], result_type); result = OpSpec::compile(b, lval, rval, std::is_signed_v); + return true; } } + return false; }); + return result; } #endif diff --git a/src/Functions/FunctionIfBase.h b/src/Functions/FunctionIfBase.h index 4c9ecf78a12..2d5f42a53a0 100644 --- a/src/Functions/FunctionIfBase.h +++ b/src/Functions/FunctionIfBase.h @@ -2,6 +2,7 @@ #include #include +#include #include "config.h" @@ -12,8 +13,11 @@ class FunctionIfBase : public IFunction { #if USE_EMBEDDED_COMPILER public: - bool isCompilableImpl(const DataTypes & types) const override + bool isCompilableImpl(const DataTypes & types, const DataTypePtr & result_type) const override { + if (!canBeNativeType(result_type)) + return false; + /// It's difficult to compare Date and DateTime - cannot use JIT compilation. bool has_date = false; bool has_datetime = false; @@ -31,43 +35,43 @@ public: if (has_date && has_datetime) return false; - if (!isCompilableType(type_removed_nullable)) + if (!canBeNativeType(type_removed_nullable)) return false; } + return true; } - llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, Values values) const override + llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const override { auto & b = static_cast &>(builder); - auto return_type = getReturnTypeImpl(types); auto * head = b.GetInsertBlock(); auto * join = llvm::BasicBlock::Create(head->getContext(), "join_block", head->getParent()); std::vector> returns; - for (size_t i = 0; i + 1 < types.size(); i += 2) + for (size_t i = 0; i + 1 < arguments.size(); i += 2) { auto * then = llvm::BasicBlock::Create(head->getContext(), "then_" + std::to_string(i), head->getParent()); auto * next = llvm::BasicBlock::Create(head->getContext(), "next_" + std::to_string(i), head->getParent()); - auto * cond = values[i]; + const auto & cond = arguments[i]; - b.CreateCondBr(nativeBoolCast(b, types[i], cond), then, next); + b.CreateCondBr(nativeBoolCast(b, cond), then, next); b.SetInsertPoint(then); - auto * value = nativeCast(b, types[i + 1], values[i + 1], return_type); + auto * value = nativeCast(b, arguments[i + 1], result_type); returns.emplace_back(b.GetInsertBlock(), value); b.CreateBr(join); b.SetInsertPoint(next); } - auto * else_value = nativeCast(b, types.back(), values.back(), return_type); + auto * else_value = nativeCast(b, arguments.back(), result_type); returns.emplace_back(b.GetInsertBlock(), else_value); b.CreateBr(join); b.SetInsertPoint(join); - auto * phi = b.CreatePHI(toNativeType(b, return_type), static_cast(returns.size())); + auto * phi = b.CreatePHI(toNativeType(b, result_type), static_cast(returns.size())); for (const auto & [block, value] : returns) phi->addIncoming(value, block); diff --git a/src/Functions/FunctionUnaryArithmetic.h b/src/Functions/FunctionUnaryArithmetic.h index 4098d58299c..259dc1c42ba 100644 --- a/src/Functions/FunctionUnaryArithmetic.h +++ b/src/Functions/FunctionUnaryArithmetic.h @@ -477,31 +477,45 @@ public: } #if USE_EMBEDDED_COMPILER - bool isCompilableImpl(const DataTypes & arguments) const override + bool isCompilableImpl(const DataTypes & arguments, const DataTypePtr & result_type) const override { if (1 != arguments.size()) return false; + if (!canBeNativeType(*arguments[0]) || !canBeNativeType(*result_type)) + return false; + return castType(arguments[0].get(), [&](const auto & type) { using DataType = std::decay_t; if constexpr (std::is_same_v || std::is_same_v) + { return false; + } else - return !IsDataTypeDecimal && Op::compilable; + { + using T0 = typename DataType::FieldType; + using T1 = typename Op::ResultType; + if constexpr (!std::is_same_v && !IsDataTypeDecimal && Op::compilable) + return true; + } + + return false; }); } - llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, Values values) const override + llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const override { - assert(1 == types.size() && 1 == values.size()); + assert(1 == arguments.size()); llvm::Value * result = nullptr; - castType(types[0].get(), [&](const auto & type) + castType(arguments[0].type.get(), [&](const auto & type) { using DataType = std::decay_t; if constexpr (std::is_same_v || std::is_same_v) + { return false; + } else { using T0 = typename DataType::FieldType; @@ -509,13 +523,16 @@ public: if constexpr (!std::is_same_v && !IsDataTypeDecimal && Op::compilable) { auto & b = static_cast &>(builder); - auto * v = nativeCast(b, types[0], values[0], std::make_shared>()); + auto * v = nativeCast(b, arguments[0], result_type); result = Op::compile(b, v, is_signed_v); + return true; } } + return false; }); + return result; } #endif diff --git a/src/Functions/FunctionsComparison.h b/src/Functions/FunctionsComparison.h index 08bc350c1d4..e3a903008f0 100644 --- a/src/Functions/FunctionsComparison.h +++ b/src/Functions/FunctionsComparison.h @@ -1384,13 +1384,13 @@ public: } #if USE_EMBEDDED_COMPILER - bool isCompilableImpl(const DataTypes & types) const override + bool isCompilableImpl(const DataTypes & arguments, const DataTypePtr & result_type) const override { - if (2 != types.size()) + if (2 != arguments.size()) return false; - WhichDataType data_type_lhs(types[0]); - WhichDataType data_type_rhs(types[1]); + WhichDataType data_type_lhs(arguments[0]); + WhichDataType data_type_rhs(arguments[1]); auto is_big_integer = [](WhichDataType type) { return type.isUInt64() || type.isInt64(); }; @@ -1400,16 +1400,18 @@ public: || (data_type_rhs.isDate() && data_type_lhs.isDateTime())) return false; /// TODO: implement (double, int_N where N > double's mantissa width) - return isCompilableType(types[0]) && isCompilableType(types[1]); + return canBeNativeType(arguments[0]) && canBeNativeType(arguments[1]) && canBeNativeType(result_type); } - llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, Values values) const override + llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const override { - assert(2 == types.size() && 2 == values.size()); + assert(2 == arguments.size()); auto & b = static_cast &>(builder); - auto [x, y] = nativeCastToCommon(b, types[0], values[0], types[1], values[1]); - auto * result = CompileOp::compile(b, x, y, typeIsSigned(*types[0]) || typeIsSigned(*types[1])); + auto * x = nativeCast(b, arguments[0], result_type); + auto * y = nativeCast(b, arguments[1], result_type); + auto * result = CompileOp::compile(b, x, y, typeIsSigned(*arguments[0].type) || typeIsSigned(*arguments[1].type)); + return b.CreateSelect(result, b.getInt8(1), b.getInt8(0)); } #endif diff --git a/src/Functions/FunctionsLogical.h b/src/Functions/FunctionsLogical.h index b2a59c51123..a25bffcdd73 100644 --- a/src/Functions/FunctionsLogical.h +++ b/src/Functions/FunctionsLogical.h @@ -184,41 +184,46 @@ public: ColumnPtr getConstantResultForNonConstArguments(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type) const override; #if USE_EMBEDDED_COMPILER - bool isCompilableImpl(const DataTypes &) const override { return useDefaultImplementationForNulls(); } + bool isCompilableImpl(const DataTypes &, const DataTypePtr &) const override { return useDefaultImplementationForNulls(); } - llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, Values values) const override + llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & values, const DataTypePtr &) const override { - assert(!types.empty() && !values.empty()); + assert(!values.empty()); auto & b = static_cast &>(builder); if constexpr (!Impl::isSaturable()) { - auto * result = nativeBoolCast(b, types[0], values[0]); - for (size_t i = 1; i < types.size(); ++i) - result = Impl::apply(b, result, nativeBoolCast(b, types[i], values[i])); + auto * result = nativeBoolCast(b, values[0]); + for (size_t i = 1; i < values.size(); ++i) + result = Impl::apply(b, result, nativeBoolCast(b, values[i])); return b.CreateSelect(result, b.getInt8(1), b.getInt8(0)); } + constexpr bool break_on_true = Impl::isSaturatedValue(true); auto * next = b.GetInsertBlock(); auto * stop = llvm::BasicBlock::Create(next->getContext(), "", next->getParent()); b.SetInsertPoint(stop); + auto * phi = b.CreatePHI(b.getInt8Ty(), static_cast(values.size())); - for (size_t i = 0; i < types.size(); ++i) + + for (size_t i = 0; i < values.size(); ++i) { b.SetInsertPoint(next); - auto * value = values[i]; - auto * truth = nativeBoolCast(b, types[i], value); - if (!types[i]->equals(DataTypeUInt8{})) + auto * value = values[i].value; + auto * truth = nativeBoolCast(b, values[i]); + if (!values[i].type->equals(DataTypeUInt8{})) value = b.CreateSelect(truth, b.getInt8(1), b.getInt8(0)); phi->addIncoming(value, b.GetInsertBlock()); - if (i + 1 < types.size()) + if (i + 1 < values.size()) { next = llvm::BasicBlock::Create(next->getContext(), "", next->getParent()); b.CreateCondBr(truth, break_on_true ? stop : next, break_on_true ? next : stop); } } + b.CreateBr(stop); b.SetInsertPoint(stop); + return phi; } #endif @@ -248,12 +253,12 @@ public: ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override; #if USE_EMBEDDED_COMPILER - bool isCompilableImpl(const DataTypes &) const override { return true; } + bool isCompilableImpl(const DataTypes &, const DataTypePtr &) const override { return true; } - llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, Values values) const override + llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & values, const DataTypePtr &) const override { auto & b = static_cast &>(builder); - return b.CreateSelect(Impl::apply(b, nativeBoolCast(b, types[0], values[0])), b.getInt8(1), b.getInt8(0)); + return b.CreateSelect(Impl::apply(b, nativeBoolCast(b, values[0])), b.getInt8(1), b.getInt8(0)); } #endif }; diff --git a/src/Functions/IFunction.cpp b/src/Functions/IFunction.cpp index 7563135f21f..4537dacaa39 100644 --- a/src/Functions/IFunction.cpp +++ b/src/Functions/IFunction.cpp @@ -484,59 +484,75 @@ DataTypePtr IFunctionOverloadResolver::getReturnTypeWithoutLowCardinality(const static std::optional removeNullables(const DataTypes & types) { + bool has_nullable = false; for (const auto & type : types) { if (!typeid_cast(type.get())) continue; + + has_nullable = true; + break; + } + + if (has_nullable) + { DataTypes filtered; + filtered.reserve(types.size()); + for (const auto & sub_type : types) filtered.emplace_back(removeNullable(sub_type)); + return filtered; } + return {}; } -bool IFunction::isCompilable(const DataTypes & arguments) const +bool IFunction::isCompilable(const DataTypes & arguments, const DataTypePtr & result_type) const { - if (useDefaultImplementationForNulls()) - if (auto denulled = removeNullables(arguments)) - return isCompilableImpl(*denulled); - return isCompilableImpl(arguments); + if (auto denulled_arguments = removeNullables(arguments)) + return isCompilableImpl(*denulled_arguments, result_type); + + return isCompilableImpl(arguments, result_type); } -llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const DataTypes & arguments, Values values) const +llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const { - auto denulled_arguments = removeNullables(arguments); - if (useDefaultImplementationForNulls() && denulled_arguments) + DataTypes arguments_types; + arguments_types.reserve(arguments.size()); + + for (const auto & argument : arguments) { + arguments_types.push_back(argument.type); + } + + auto denulled_arguments_types = removeNullables(arguments_types); + if (useDefaultImplementationForNulls() && denulled_arguments_types) { auto & b = static_cast &>(builder); - std::vector unwrapped_values; - std::vector is_null_values; + ValuesWithType unwrapped_arguments; + unwrapped_arguments.reserve(arguments.size()); - unwrapped_values.reserve(arguments.size()); - is_null_values.reserve(arguments.size()); + std::vector is_null_values; for (size_t i = 0; i < arguments.size(); ++i) { - auto * value = values[i]; + const auto & argument = arguments[i]; + llvm::Value * unwrapped_value = argument.value; - WhichDataType data_type(arguments[i]); - if (data_type.isNullable()) + if (argument.type->isNullable()) { - unwrapped_values.emplace_back(b.CreateExtractValue(value, {0})); - is_null_values.emplace_back(b.CreateExtractValue(value, {1})); - } - else - { - unwrapped_values.emplace_back(value); + unwrapped_value = b.CreateExtractValue(argument.value, {0}); + is_null_values.emplace_back(b.CreateExtractValue(argument.value, {1})); } + + unwrapped_arguments.emplace_back(unwrapped_value, (*denulled_arguments_types)[i]); } - auto * result = compileImpl(builder, *denulled_arguments, unwrapped_values); + auto * result = compileImpl(builder, unwrapped_arguments, removeNullable(result_type)); - auto * nullable_structure_type = toNativeType(b, makeNullable(getReturnTypeImpl(*denulled_arguments))); + auto * nullable_structure_type = toNativeType(b, makeNullable(getReturnTypeImpl(*denulled_arguments_types))); auto * nullable_structure_value = llvm::Constant::getNullValue(nullable_structure_type); auto * nullable_structure_with_result_value = b.CreateInsertValue(nullable_structure_value, result, {0}); @@ -548,7 +564,7 @@ llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const DataTypes return b.CreateInsertValue(nullable_structure_with_result_value, nullable_structure_result_null, {1}); } - return compileImpl(builder, arguments, std::move(values)); + return compileImpl(builder, arguments, result_type); } #endif diff --git a/src/Functions/IFunction.h b/src/Functions/IFunction.h index cf2dcc9617e..433cb61d04e 100644 --- a/src/Functions/IFunction.h +++ b/src/Functions/IFunction.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -121,8 +122,6 @@ private: using ExecutableFunctionPtr = std::shared_ptr; -using Values = std::vector; - /** Function with known arguments and return type (when the specific overload was chosen). * It is also the point where all function-specific properties are known. */ @@ -162,7 +161,7 @@ public: * templates with default arguments is impossible and including LLVM in such a generic header * as this one is a major pain. */ - virtual llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, Values /*values*/) const + virtual llvm::Value * compile(llvm::IRBuilderBase & /*builder*/, const ValuesWithType & /*arguments*/) const { throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName()); } @@ -530,9 +529,9 @@ public: #if USE_EMBEDDED_COMPILER - bool isCompilable(const DataTypes & arguments) const; + bool isCompilable(const DataTypes & arguments, const DataTypePtr & result_type) const; - llvm::Value * compile(llvm::IRBuilderBase &, const DataTypes & arguments, Values values) const; + llvm::Value * compile(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const; #endif @@ -540,9 +539,9 @@ protected: #if USE_EMBEDDED_COMPILER - virtual bool isCompilableImpl(const DataTypes &) const { return false; } + virtual bool isCompilableImpl(const DataTypes & /*arguments*/, const DataTypePtr & /*result_type*/) const { return false; } - virtual llvm::Value * compileImpl(llvm::IRBuilderBase &, const DataTypes &, Values) const + virtual llvm::Value * compileImpl(llvm::IRBuilderBase & /*builder*/, const ValuesWithType & /*arguments*/, const DataTypePtr & /*result_type*/) const { throw Exception(ErrorCodes::NOT_IMPLEMENTED, "{} is not JIT-compilable", getName()); } diff --git a/src/Functions/IFunctionAdaptors.h b/src/Functions/IFunctionAdaptors.h index 4ecb45167cc..123fdbc2f50 100644 --- a/src/Functions/IFunctionAdaptors.h +++ b/src/Functions/IFunctionAdaptors.h @@ -55,11 +55,11 @@ public: #if USE_EMBEDDED_COMPILER - bool isCompilable() const override { return function->isCompilable(getArgumentTypes()); } + bool isCompilable() const override { return function->isCompilable(getArgumentTypes(), getResultType()); } - llvm::Value * compile(llvm::IRBuilderBase & builder, Values values) const override + llvm::Value * compile(llvm::IRBuilderBase & builder, const ValuesWithType & compile_arguments) const override { - return function->compile(builder, getArgumentTypes(), std::move(values)); + return function->compile(builder, compile_arguments, getResultType()); } #endif diff --git a/src/Interpreters/ExpressionJIT.cpp b/src/Interpreters/ExpressionJIT.cpp index dfc88e97052..0eacb598fbe 100644 --- a/src/Interpreters/ExpressionJIT.cpp +++ b/src/Interpreters/ExpressionJIT.cpp @@ -160,9 +160,9 @@ public: bool isCompilable() const override { return true; } - llvm::Value * compile(llvm::IRBuilderBase & builder, Values values) const override + llvm::Value * compile(llvm::IRBuilderBase & builder, const ValuesWithType & arguments) const override { - return dag.compile(builder, values); + return dag.compile(builder, arguments).value; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & arguments) const override diff --git a/src/Interpreters/JIT/CHJIT.h b/src/Interpreters/JIT/CHJIT.h index cde1129c010..fc883802426 100644 --- a/src/Interpreters/JIT/CHJIT.h +++ b/src/Interpreters/JIT/CHJIT.h @@ -19,14 +19,14 @@ class JITModuleMemoryManager; class JITSymbolResolver; class JITCompiler; -/** Custom jit implementation +/** Custom JIT implementation. * Main use cases: * 1. Compiled functions in module. * 2. Release memory for compiled functions. * * In LLVM library there are 2 main JIT stacks MCJIT and ORCv2. * - * Main reasons for custom implementation vs MCJIT + * Main reasons for custom implementation vs MCJIT. * MCJIT keeps llvm::Module and compiled object code before linking process after module was compiled. * llvm::Module can be removed, but compiled object code cannot be removed. Memory for compiled code * will be release only during MCJIT instance destruction. It is too expensive to create MCJIT diff --git a/src/Interpreters/JIT/CompileDAG.cpp b/src/Interpreters/JIT/CompileDAG.cpp index 2c5c7731150..972cea7bbf8 100644 --- a/src/Interpreters/JIT/CompileDAG.cpp +++ b/src/Interpreters/JIT/CompileDAG.cpp @@ -21,14 +21,14 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } -llvm::Value * CompileDAG::compile(llvm::IRBuilderBase & builder, Values input_nodes_values) const +ValueWithType CompileDAG::compile(llvm::IRBuilderBase & builder, const ValuesWithType & input_nodes_values) const { assert(input_nodes_values.size() == getInputNodesCount()); llvm::IRBuilder<> & b = static_cast &>(builder); - PaddedPODArray compiled_values; - compiled_values.resize_fill(nodes.size()); + ValuesWithType compiled_values; + compiled_values.resize(nodes.size()); size_t input_nodes_values_index = 0; size_t compiled_values_index = 0; @@ -44,31 +44,26 @@ llvm::Value * CompileDAG::compile(llvm::IRBuilderBase & builder, Values input_no case CompileType::CONSTANT: { auto * native_value = getColumnNativeValue(b, node.result_type, *node.column, 0); - if (!native_value) - throw Exception(ErrorCodes::LOGICAL_ERROR, - "Cannot find native value for constant column with type {}", - node.result_type->getName()); - - compiled_values[compiled_values_index] = native_value; + compiled_values[compiled_values_index] = {native_value, node.result_type}; break; } case CompileType::FUNCTION: { - Values temporary_values; + ValuesWithType temporary_values; temporary_values.reserve(node.arguments.size()); for (auto argument_index : node.arguments) { - assert(compiled_values[argument_index] != nullptr); + assert(compiled_values[argument_index].value != nullptr); temporary_values.emplace_back(compiled_values[argument_index]); } - compiled_values[compiled_values_index] = node.function->compile(builder, temporary_values); + compiled_values[compiled_values_index] = {node.function->compile(builder, temporary_values), node.result_type}; break; } case CompileType::INPUT: { - compiled_values[compiled_values_index] = input_nodes_values[input_nodes_values_index]; + compiled_values[compiled_values_index] = {input_nodes_values[input_nodes_values_index].value, node.result_type}; ++input_nodes_values_index; break; } diff --git a/src/Interpreters/JIT/CompileDAG.h b/src/Interpreters/JIT/CompileDAG.h index a05fa629561..77a02230f55 100644 --- a/src/Interpreters/JIT/CompileDAG.h +++ b/src/Interpreters/JIT/CompileDAG.h @@ -53,7 +53,7 @@ public: std::vector arguments; }; - llvm::Value * compile(llvm::IRBuilderBase & builder, Values input_nodes_values) const; + ValueWithType compile(llvm::IRBuilderBase & builder, const ValuesWithType & input_nodes_values_with_type) const; std::string dump() const; diff --git a/src/Interpreters/JIT/compileFunction.cpp b/src/Interpreters/JIT/compileFunction.cpp index a7233433861..a5a646879bb 100644 --- a/src/Interpreters/JIT/compileFunction.cpp +++ b/src/Interpreters/JIT/compileFunction.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include namespace @@ -107,7 +109,7 @@ static void compileFunction(llvm::Module & module, const IFunctionBase & functio /// Initialize column row values - Values arguments; + ValuesWithType arguments; arguments.reserve(function_argument_types.size()); for (size_t i = 0; i < function_argument_types.size(); ++i) @@ -120,7 +122,7 @@ static void compileFunction(llvm::Module & module, const IFunctionBase & functio if (!type->isNullable()) { - arguments.emplace_back(column_element_value); + arguments.emplace_back(column_element_value, type); continue; } @@ -128,12 +130,12 @@ static void compileFunction(llvm::Module & module, const IFunctionBase & functio auto * is_null = b.CreateICmpNE(column_is_null_element_value, b.getInt8(0)); auto * nullable_unitialized = llvm::Constant::getNullValue(toNullableType(b, column.data_element_type)); auto * nullable_value = b.CreateInsertValue(b.CreateInsertValue(nullable_unitialized, column_element_value, {0}), is_null, {1}); - arguments.emplace_back(nullable_value); + arguments.emplace_back(nullable_value, type); } /// Compile values for column rows and store compiled value in result column - auto * result = function.compile(b, std::move(arguments)); + auto * result = function.compile(b, arguments); auto * result_column_element_ptr = b.CreateGEP(columns.back().data_element_type, columns.back().data_ptr, counter_phi); if (columns.back().null_data_ptr) @@ -298,24 +300,24 @@ static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, else aggregation_place = places_arg; - std::vector function_arguments_values; + ValuesWithType function_arguments; previous_columns_size = 0; for (const auto & function : functions) { - auto arguments_types = function.function->getArgumentTypes(); + const auto & arguments_types = function.function->getArgumentTypes(); size_t function_arguments_size = arguments_types.size(); for (size_t column_argument_index = 0; column_argument_index < function_arguments_size; ++column_argument_index) { auto & column = columns[previous_columns_size + column_argument_index]; - auto & argument_type = arguments_types[column_argument_index]; + const auto & argument_type = arguments_types[column_argument_index]; auto * column_data_element = b.CreateLoad(column.data_element_type, b.CreateGEP(column.data_element_type, column.data_ptr, counter_phi)); if (!argument_type->isNullable()) { - function_arguments_values.push_back(column_data_element); + function_arguments.emplace_back(column_data_element, argument_type); continue; } @@ -324,16 +326,16 @@ static void compileAddIntoAggregateStatesFunctions(llvm::Module & module, auto * nullable_unitialized = llvm::Constant::getNullValue(toNullableType(b, column.data_element_type)); auto * first_insert = b.CreateInsertValue(nullable_unitialized, column_data_element, {0}); auto * nullable_value = b.CreateInsertValue(first_insert, is_null, {1}); - function_arguments_values.push_back(nullable_value); + function_arguments.emplace_back(nullable_value, argument_type); } size_t aggregate_function_offset = function.aggregate_data_offset; auto * aggregation_place_with_offset = b.CreateConstInBoundsGEP1_64(b.getInt8Ty(), aggregation_place, aggregate_function_offset); const auto * aggregate_function_ptr = function.function; - aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, arguments_types, function_arguments_values); + aggregate_function_ptr->compileAdd(b, aggregation_place_with_offset, function_arguments); - function_arguments_values.clear(); + function_arguments.clear(); previous_columns_size += function_arguments_size; } From fcc149a9cdfc3887d3be0bb5efd65517196f9801 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Sat, 3 Jun 2023 21:31:26 +0300 Subject: [PATCH 2/5] Added tests --- ...1_jit_functions_comparison_crash.reference | 0 .../02771_jit_functions_comparison_crash.sql | 36 +++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 tests/queries/0_stateless/02771_jit_functions_comparison_crash.reference create mode 100644 tests/queries/0_stateless/02771_jit_functions_comparison_crash.sql diff --git a/tests/queries/0_stateless/02771_jit_functions_comparison_crash.reference b/tests/queries/0_stateless/02771_jit_functions_comparison_crash.reference new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/queries/0_stateless/02771_jit_functions_comparison_crash.sql b/tests/queries/0_stateless/02771_jit_functions_comparison_crash.sql new file mode 100644 index 00000000000..e02f1a3382d --- /dev/null +++ b/tests/queries/0_stateless/02771_jit_functions_comparison_crash.sql @@ -0,0 +1,36 @@ +SET compile_expressions = 1; +SET min_count_to_compile_expression = 0; + +DROP TABLE IF EXISTS test_table_1; +CREATE TABLE test_table_1 +( + pkey UInt32, + c8 UInt32, + c9 String, + c10 Float32, + c11 String +) ENGINE = MergeTree ORDER BY pkey; + +DROP TABLE IF EXISTS test_table_2; +CREATE TABLE test_table_2 +( + vkey UInt32, + pkey UInt32, + c15 UInt32 +) ENGINE = MergeTree ORDER BY vkey; + +WITH test_cte AS +( + SELECT + ref_10.c11 as c_2_c2350_1, + ref_9.c9 as c_2_c2351_2 + FROM + test_table_1 as ref_9 + RIGHT OUTER JOIN test_table_1 as ref_10 ON (ref_9.c11 = ref_10.c9) + INNER JOIN test_table_2 as ref_11 ON (ref_10.c8 = ref_11.vkey) + WHERE ((ref_10.pkey + ref_11.pkey) BETWEEN ref_11.vkey AND (CASE WHEN (-30.87 >= ref_9.c10) THEN ref_11.c15 ELSE ref_11.pkey END)) +) +SELECT ref_13.c_2_c2350_1 as c_2_c2357_3 FROM test_cte as ref_13 WHERE (ref_13.c_2_c2351_2) in (select ref_14.c_2_c2351_2 as c_5_c2352_0 FROM test_cte as ref_14); + +DROP TABLE test_table_1; +DROP TABLE test_table_2; From 6e26fde707ff68a864d9799ea8af2b9ef9a049ab Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Sun, 4 Jun 2023 15:01:48 +0300 Subject: [PATCH 3/5] Updated tests --- src/DataTypes/Native.cpp | 1 + src/Functions/FunctionBinaryArithmetic.h | 6 ++++++ src/Functions/IFunction.cpp | 3 +-- src/Interpreters/JIT/CompileDAG.cpp | 5 ----- tests/queries/0_stateless/02772_jit_date_time_add.reference | 1 + tests/queries/0_stateless/02772_jit_date_time_add.sql | 6 ++++++ 6 files changed, 15 insertions(+), 7 deletions(-) create mode 100644 tests/queries/0_stateless/02772_jit_date_time_add.reference create mode 100644 tests/queries/0_stateless/02772_jit_date_time_add.sql diff --git a/src/DataTypes/Native.cpp b/src/DataTypes/Native.cpp index acbd70ba04f..6f1ea851dce 100644 --- a/src/DataTypes/Native.cpp +++ b/src/DataTypes/Native.cpp @@ -12,6 +12,7 @@ namespace DB namespace ErrorCodes { extern const int NOT_IMPLEMENTED; + extern const int LOGICAL_ERROR; } bool typeIsSigned(const IDataType & type) diff --git a/src/Functions/FunctionBinaryArithmetic.h b/src/Functions/FunctionBinaryArithmetic.h index 50d8abb9bcc..c699da4eaf6 100644 --- a/src/Functions/FunctionBinaryArithmetic.h +++ b/src/Functions/FunctionBinaryArithmetic.h @@ -2054,6 +2054,12 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A if (!canBeNativeType(*arguments[0]) || !canBeNativeType(*arguments[1]) || !canBeNativeType(*result_type)) return false; + WhichDataType data_type_lhs(arguments[0]); + WhichDataType data_type_rhs(arguments[1]); + if ((data_type_lhs.isDateOrDate32() || data_type_lhs.isDateTime()) || + (data_type_rhs.isDateOrDate32() || data_type_rhs.isDateTime())) + return false; + return castBothTypes(arguments[0].get(), arguments[1].get(), [&](const auto & left, const auto & right) { using LeftDataType = std::decay_t; diff --git a/src/Functions/IFunction.cpp b/src/Functions/IFunction.cpp index 4537dacaa39..650b54d9a37 100644 --- a/src/Functions/IFunction.cpp +++ b/src/Functions/IFunction.cpp @@ -522,9 +522,8 @@ llvm::Value * IFunction::compile(llvm::IRBuilderBase & builder, const ValuesWith DataTypes arguments_types; arguments_types.reserve(arguments.size()); - for (const auto & argument : arguments) { + for (const auto & argument : arguments) arguments_types.push_back(argument.type); - } auto denulled_arguments_types = removeNullables(arguments_types); if (useDefaultImplementationForNulls() && denulled_arguments_types) diff --git a/src/Interpreters/JIT/CompileDAG.cpp b/src/Interpreters/JIT/CompileDAG.cpp index 972cea7bbf8..6da17fb4c67 100644 --- a/src/Interpreters/JIT/CompileDAG.cpp +++ b/src/Interpreters/JIT/CompileDAG.cpp @@ -16,11 +16,6 @@ namespace DB { -namespace ErrorCodes -{ - extern const int LOGICAL_ERROR; -} - ValueWithType CompileDAG::compile(llvm::IRBuilderBase & builder, const ValuesWithType & input_nodes_values) const { assert(input_nodes_values.size() == getInputNodesCount()); diff --git a/tests/queries/0_stateless/02772_jit_date_time_add.reference b/tests/queries/0_stateless/02772_jit_date_time_add.reference new file mode 100644 index 00000000000..dec7d2fabd2 --- /dev/null +++ b/tests/queries/0_stateless/02772_jit_date_time_add.reference @@ -0,0 +1 @@ +\N diff --git a/tests/queries/0_stateless/02772_jit_date_time_add.sql b/tests/queries/0_stateless/02772_jit_date_time_add.sql new file mode 100644 index 00000000000..61028ac4172 --- /dev/null +++ b/tests/queries/0_stateless/02772_jit_date_time_add.sql @@ -0,0 +1,6 @@ +SET compile_expressions = 1; +SET min_count_to_compile_expression = 0; + +SELECT DISTINCT result FROM (SELECT toStartOfFifteenMinutes(toDateTime(toStartOfFifteenMinutes(toDateTime(1000.0001220703125) + (number * 65536))) + (number * 9223372036854775807)) AS result FROM system.numbers LIMIT 1048576) ORDER BY result DESC NULLS FIRST FORMAT Null; -- { serverError 407 } +SELECT DISTINCT result FROM (SELECT toStartOfFifteenMinutes(toDateTime(toStartOfFifteenMinutes(toDateTime(1000.0001220703125) + (number * 65536))) + toInt64(number * 9223372036854775807)) AS result FROM system.numbers LIMIT 1048576) ORDER BY result DESC NULLS FIRST FORMAT Null; +SELECT round(round(round(round(round(100)), round(round(round(round(NULL), round(65535)), toTypeName(now() + 9223372036854775807) LIKE 'DateTime%DateTime%DateTime%DateTime%', round(-2)), 255), round(NULL)))); From aa28a1f25915fb7225f6a596a83a4dfa9f8a5f6b Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Tue, 13 Jun 2023 11:44:15 +0300 Subject: [PATCH 4/5] Fixed tests --- src/Functions/FunctionsComparison.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Functions/FunctionsComparison.h b/src/Functions/FunctionsComparison.h index e3a903008f0..767bdf2c823 100644 --- a/src/Functions/FunctionsComparison.h +++ b/src/Functions/FunctionsComparison.h @@ -1400,16 +1400,20 @@ public: || (data_type_rhs.isDate() && data_type_lhs.isDateTime())) return false; /// TODO: implement (double, int_N where N > double's mantissa width) - return canBeNativeType(arguments[0]) && canBeNativeType(arguments[1]) && canBeNativeType(result_type); + DataTypePtr common_type = getLeastSupertype(arguments); + return canBeNativeType(arguments[0]) && canBeNativeType(arguments[1]) && canBeNativeType(result_type) && canBeNativeType(common_type); } - llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr & result_type) const override + llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr &) const override { assert(2 == arguments.size()); + DataTypePtr common_type = getLeastSupertype(DataTypes{arguments[0].type, arguments[1].type}); + auto & b = static_cast &>(builder); - auto * x = nativeCast(b, arguments[0], result_type); - auto * y = nativeCast(b, arguments[1], result_type); + auto * x = nativeCast(b, arguments[0], common_type); + auto * y = nativeCast(b, arguments[1], common_type); + auto * result = CompileOp::compile(b, x, y, typeIsSigned(*arguments[0].type) || typeIsSigned(*arguments[1].type)); return b.CreateSelect(result, b.getInt8(1), b.getInt8(0)); From 7e5017dd31a4146863413b457de0e7af296525c4 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Tue, 20 Jun 2023 11:39:21 +0300 Subject: [PATCH 5/5] Fixed tests --- base/base/find_symbols.h | 1 + programs/copier/ShardPartitionPiece.h | 2 + src/Access/Common/AccessFlags.h | 1 + .../assertProcessUserMatchesDataOwner.h | 2 + src/Core/QualifiedTableName.h | 2 +- src/Functions/FunctionsComparison.h | 37 ------------------- src/Interpreters/JIT/compileFunction.cpp | 8 ++-- .../ParallelReplicasReadingCoordinator.cpp | 2 +- src/Storages/MergeTree/RangesInDataPart.cpp | 2 +- 9 files changed, 13 insertions(+), 44 deletions(-) diff --git a/base/base/find_symbols.h b/base/base/find_symbols.h index a8747ecc9b7..83232669c04 100644 --- a/base/base/find_symbols.h +++ b/base/base/find_symbols.h @@ -2,6 +2,7 @@ #include #include +#include #if defined(__SSE2__) #include diff --git a/programs/copier/ShardPartitionPiece.h b/programs/copier/ShardPartitionPiece.h index aba378d466d..453364c0fc8 100644 --- a/programs/copier/ShardPartitionPiece.h +++ b/programs/copier/ShardPartitionPiece.h @@ -2,6 +2,8 @@ #include +#include + namespace DB { diff --git a/src/Access/Common/AccessFlags.h b/src/Access/Common/AccessFlags.h index 270ee1c0045..c9672da7d92 100644 --- a/src/Access/Common/AccessFlags.h +++ b/src/Access/Common/AccessFlags.h @@ -4,6 +4,7 @@ #include #include #include +#include namespace DB diff --git a/src/Common/assertProcessUserMatchesDataOwner.h b/src/Common/assertProcessUserMatchesDataOwner.h index b31d795da71..7a6c5d36335 100644 --- a/src/Common/assertProcessUserMatchesDataOwner.h +++ b/src/Common/assertProcessUserMatchesDataOwner.h @@ -1,5 +1,7 @@ #pragma once + #include +#include namespace DB { diff --git a/src/Core/QualifiedTableName.h b/src/Core/QualifiedTableName.h index 3310130629d..bf05bd59caf 100644 --- a/src/Core/QualifiedTableName.h +++ b/src/Core/QualifiedTableName.h @@ -127,7 +127,7 @@ namespace fmt template auto format(const DB::QualifiedTableName & name, FormatContext & ctx) { - return format_to(ctx.out(), "{}.{}", DB::backQuoteIfNeed(name.database), DB::backQuoteIfNeed(name.table)); + return fmt::format_to(ctx.out(), "{}.{}", DB::backQuoteIfNeed(name.database), DB::backQuoteIfNeed(name.table)); } }; } diff --git a/src/Functions/FunctionsComparison.h b/src/Functions/FunctionsComparison.h index 767bdf2c823..66269f72866 100644 --- a/src/Functions/FunctionsComparison.h +++ b/src/Functions/FunctionsComparison.h @@ -1382,43 +1382,6 @@ public: return executeGeneric(col_with_type_and_name_left, col_with_type_and_name_right); } } - -#if USE_EMBEDDED_COMPILER - bool isCompilableImpl(const DataTypes & arguments, const DataTypePtr & result_type) const override - { - if (2 != arguments.size()) - return false; - - WhichDataType data_type_lhs(arguments[0]); - WhichDataType data_type_rhs(arguments[1]); - - auto is_big_integer = [](WhichDataType type) { return type.isUInt64() || type.isInt64(); }; - - if ((is_big_integer(data_type_lhs) && data_type_rhs.isFloat()) - || (is_big_integer(data_type_rhs) && data_type_lhs.isFloat()) - || (data_type_lhs.isDate() && data_type_rhs.isDateTime()) - || (data_type_rhs.isDate() && data_type_lhs.isDateTime())) - return false; /// TODO: implement (double, int_N where N > double's mantissa width) - - DataTypePtr common_type = getLeastSupertype(arguments); - return canBeNativeType(arguments[0]) && canBeNativeType(arguments[1]) && canBeNativeType(result_type) && canBeNativeType(common_type); - } - - llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const ValuesWithType & arguments, const DataTypePtr &) const override - { - assert(2 == arguments.size()); - - DataTypePtr common_type = getLeastSupertype(DataTypes{arguments[0].type, arguments[1].type}); - - auto & b = static_cast &>(builder); - auto * x = nativeCast(b, arguments[0], common_type); - auto * y = nativeCast(b, arguments[1], common_type); - - auto * result = CompileOp::compile(b, x, y, typeIsSigned(*arguments[0].type) || typeIsSigned(*arguments[1].type)); - - return b.CreateSelect(result, b.getInt8(1), b.getInt8(0)); - } -#endif }; } diff --git a/src/Interpreters/JIT/compileFunction.cpp b/src/Interpreters/JIT/compileFunction.cpp index a5a646879bb..fb8dec665b4 100644 --- a/src/Interpreters/JIT/compileFunction.cpp +++ b/src/Interpreters/JIT/compileFunction.cpp @@ -118,7 +118,7 @@ static void compileFunction(llvm::Module & module, const IFunctionBase & functio const auto & type = function_argument_types[i]; auto * column_data_ptr = column.data_ptr; - auto * column_element_value = b.CreateLoad(column.data_element_type, b.CreateGEP(column.data_element_type, column_data_ptr, counter_phi)); + auto * column_element_value = b.CreateLoad(column.data_element_type, b.CreateInBoundsGEP(column.data_element_type, column_data_ptr, counter_phi)); if (!type->isNullable()) { @@ -126,7 +126,7 @@ static void compileFunction(llvm::Module & module, const IFunctionBase & functio continue; } - auto * column_is_null_element_value = b.CreateLoad(b.getInt8Ty(), b.CreateGEP(b.getInt8Ty(), column.null_data_ptr, counter_phi)); + auto * column_is_null_element_value = b.CreateLoad(b.getInt8Ty(), b.CreateInBoundsGEP(b.getInt8Ty(), column.null_data_ptr, counter_phi)); auto * is_null = b.CreateICmpNE(column_is_null_element_value, b.getInt8(0)); auto * nullable_unitialized = llvm::Constant::getNullValue(toNullableType(b, column.data_element_type)); auto * nullable_value = b.CreateInsertValue(b.CreateInsertValue(nullable_unitialized, column_element_value, {0}), is_null, {1}); @@ -136,12 +136,12 @@ static void compileFunction(llvm::Module & module, const IFunctionBase & functio /// Compile values for column rows and store compiled value in result column auto * result = function.compile(b, arguments); - auto * result_column_element_ptr = b.CreateGEP(columns.back().data_element_type, columns.back().data_ptr, counter_phi); + auto * result_column_element_ptr = b.CreateInBoundsGEP(columns.back().data_element_type, columns.back().data_ptr, counter_phi); if (columns.back().null_data_ptr) { b.CreateStore(b.CreateExtractValue(result, {0}), result_column_element_ptr); - auto * result_column_is_null_element_ptr = b.CreateGEP(b.getInt8Ty(), columns.back().null_data_ptr, counter_phi); + auto * result_column_is_null_element_ptr = b.CreateInBoundsGEP(b.getInt8Ty(), columns.back().null_data_ptr, counter_phi); auto * is_result_column_element_null = b.CreateSelect(b.CreateExtractValue(result, {1}), b.getInt8(1), b.getInt8(0)); b.CreateStore(is_result_column_element_null, result_column_is_null_element_ptr); } diff --git a/src/Storages/MergeTree/ParallelReplicasReadingCoordinator.cpp b/src/Storages/MergeTree/ParallelReplicasReadingCoordinator.cpp index bb044d15ba2..2814d13cff0 100644 --- a/src/Storages/MergeTree/ParallelReplicasReadingCoordinator.cpp +++ b/src/Storages/MergeTree/ParallelReplicasReadingCoordinator.cpp @@ -43,7 +43,7 @@ struct fmt::formatter template auto format(const DB::Part & part, FormatContext & ctx) { - return format_to(ctx.out(), "{} in replicas [{}]", part.description.describe(), fmt::join(part.replicas, ", ")); + return fmt::format_to(ctx.out(), "{} in replicas [{}]", part.description.describe(), fmt::join(part.replicas, ", ")); } }; diff --git a/src/Storages/MergeTree/RangesInDataPart.cpp b/src/Storages/MergeTree/RangesInDataPart.cpp index 6203f9f7483..e64e9ab0b2a 100644 --- a/src/Storages/MergeTree/RangesInDataPart.cpp +++ b/src/Storages/MergeTree/RangesInDataPart.cpp @@ -15,7 +15,7 @@ struct fmt::formatter template auto format(const DB::RangesInDataPartDescription & range, FormatContext & ctx) { - return format_to(ctx.out(), "{}", range.describe()); + return fmt::format_to(ctx.out(), "{}", range.describe()); } };