diff --git a/src/Columns/MaskOperations.cpp b/src/Columns/MaskOperations.cpp index c5cc8777f7d..cf585262086 100644 --- a/src/Columns/MaskOperations.cpp +++ b/src/Columns/MaskOperations.cpp @@ -69,6 +69,7 @@ INSTANTIATE(Decimal128) INSTANTIATE(Decimal256) INSTANTIATE(DateTime64) INSTANTIATE(char *) +INSTANTIATE(UUID) #undef INSTANTIATE @@ -251,14 +252,18 @@ void executeColumnIfNeeded(ColumnWithTypeAndName & column) column = typeid_cast(column_function)->reduce(); } -bool checkArgumentsForColumnFunction(const ColumnsWithTypeAndName & arguments) + +int checkShirtCircuitArguments(const ColumnsWithTypeAndName & arguments) { - for (const auto & arg : arguments) + int last_short_circuit_argument_index = -1; + for (size_t i = 0; i != arguments.size(); ++i) { - if (checkAndGetColumn(*arg.column)) - return true; + const auto * column_func = checkAndGetColumn(*arguments[i].column); + if (column_func && column_func->isShortCircuitArgument()) + last_short_circuit_argument_index = i; } - return false; + + return last_short_circuit_argument_index; } } diff --git a/src/Columns/MaskOperations.h b/src/Columns/MaskOperations.h index f4d49a6c65b..ee005e11f24 100644 --- a/src/Columns/MaskOperations.h +++ b/src/Columns/MaskOperations.h @@ -23,6 +23,6 @@ void maskedExecute(ColumnWithTypeAndName & column, const PaddedPODArray & void executeColumnIfNeeded(ColumnWithTypeAndName & column); -bool checkArgumentsForColumnFunction(const ColumnsWithTypeAndName & arguments); +int checkShirtCircuitArguments(const ColumnsWithTypeAndName & arguments); } diff --git a/src/Functions/FunctionUnixTimestamp64.h b/src/Functions/FunctionUnixTimestamp64.h index d292b14aabb..ad14f05663f 100644 --- a/src/Functions/FunctionUnixTimestamp64.h +++ b/src/Functions/FunctionUnixTimestamp64.h @@ -100,6 +100,7 @@ public: String getName() const override { return name; } size_t getNumberOfArguments() const override { return 0; } bool isVariadic() const override { return true; } + bool isSuitableForShortCircuitArgumentsExecution() const override { return false; } bool useDefaultImplementationForConstants() const override { return true; } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override diff --git a/src/Functions/FunctionsLogical.cpp b/src/Functions/FunctionsLogical.cpp index d55dc3cc758..24d3aa36447 100644 --- a/src/Functions/FunctionsLogical.cpp +++ b/src/Functions/FunctionsLogical.cpp @@ -514,22 +514,20 @@ void FunctionAnyArityLogical::executeShortCircuitArguments(ColumnsWi if (Name::name != NameAnd::name && Name::name != NameOr::name) throw Exception("Function " + getName() + " doesn't support short circuit execution", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - if (!checkArgumentsForColumnFunction(arguments)) + int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments); + if (last_short_circuit_argument_index < 0) return; bool reverse = Name::name != NameAnd::name; UInt8 null_value = Name::name == NameAnd::name ? 1 : 0; UInt8 value_for_mask_expanding = Name::name == NameAnd::name ? 0 : 1; executeColumnIfNeeded(arguments[0]); + IColumn::Filter mask; - getMaskFromColumn(arguments[0].column, mask, reverse, nullptr, null_value); - - for (size_t i = 1; i < arguments.size(); ++i) + for (int i = 1; i <= last_short_circuit_argument_index; ++i) { - if (isColumnFunction(*arguments[i].column)) - maskedExecute(arguments[i], mask, false, &value_for_mask_expanding); - - getMaskFromColumn(arguments[i].column, mask, reverse, nullptr, null_value); + getMaskFromColumn(arguments[i - 1].column, mask, reverse, nullptr, null_value); + maskedExecute(arguments[i], mask, false, &value_for_mask_expanding); } } diff --git a/src/Functions/if.cpp b/src/Functions/if.cpp index 9e208b61e43..deaea68eb85 100644 --- a/src/Functions/if.cpp +++ b/src/Functions/if.cpp @@ -923,11 +923,12 @@ public: void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override { - if (!checkArgumentsForColumnFunction(arguments)) + int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments); + if (last_short_circuit_argument_index < 0) return; executeColumnIfNeeded(arguments[0]); - if (isColumnFunction(*arguments[1].column) || isColumnFunction(*arguments[2].column)) + if (last_short_circuit_argument_index > 0) { IColumn::Filter mask; getMaskFromColumn(arguments[0].column, mask); diff --git a/src/Functions/multiIf.cpp b/src/Functions/multiIf.cpp index 6df44d7d4fb..6de2f3765a9 100644 --- a/src/Functions/multiIf.cpp +++ b/src/Functions/multiIf.cpp @@ -111,7 +111,8 @@ public: void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) const override { - if (!checkArgumentsForColumnFunction(arguments)) + int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments); + if (last_short_circuit_argument_index < 0) return; executeColumnIfNeeded(arguments[0]); @@ -119,24 +120,19 @@ public: IColumn::Filter mask_disjunctions = IColumn::Filter(arguments[0].column->size(), 0); UInt8 default_value_for_mask_expanding = 0; - size_t i = 1; - while (i < arguments.size()) + int i = 1; + while (i <= last_short_circuit_argument_index) { getMaskFromColumn(arguments[i - 1].column, current_mask); - disjunctionMasks(mask_disjunctions, current_mask); - if (isColumnFunction(*arguments[i].column)) - maskedExecute(arguments[i], current_mask); + maskedExecute(arguments[i], current_mask); ++i; + if (i > last_short_circuit_argument_index) + break; - if (isColumnFunction(*arguments[i].column)) - { - if (i < arguments.size() - 1) - maskedExecute(arguments[i], mask_disjunctions, true, &default_value_for_mask_expanding); - else - maskedExecute(arguments[i], mask_disjunctions, true); - } - + disjunctionMasks(mask_disjunctions, current_mask); + UInt8 * default_value_ptr = i + 1 == int(arguments.size()) ? nullptr: &default_value_for_mask_expanding; + maskedExecute(arguments[i], mask_disjunctions, true, default_value_ptr); ++i; } } diff --git a/src/Interpreters/ActionsDAG.h b/src/Interpreters/ActionsDAG.h index 77fa3ed7be3..be06214924e 100644 --- a/src/Interpreters/ActionsDAG.h +++ b/src/Interpreters/ActionsDAG.h @@ -67,6 +67,13 @@ public: using NodeRawPtrs = std::vector; using NodeRawConstPtrs = std::vector; + enum class LazyExecution + { + DISABLED, + ENABLED, + FORCE_ENABLED, + }; + struct Node { NodeRawConstPtrs children; @@ -90,9 +97,9 @@ public: /// For COLUMN node and propagated constants. ColumnPtr column; - void toTree(JSONBuilder::JSONMap & map) const; + LazyExecution lazy_execution = LazyExecution::DISABLED; - bool is_lazy_executed = false; + void toTree(JSONBuilder::JSONMap & map) const; }; /// NOTE: std::list is an implementation detail. diff --git a/src/Interpreters/ExpressionActions.cpp b/src/Interpreters/ExpressionActions.cpp index fb5c3aa8039..edaa932d4fa 100644 --- a/src/Interpreters/ExpressionActions.cpp +++ b/src/Interpreters/ExpressionActions.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -69,38 +70,26 @@ ExpressionActionsPtr ExpressionActions::clone() const return std::make_shared(*this); } -bool ExpressionActions::rewriteShortCircuitArguments(const ActionsDAG::NodeRawConstPtrs & children, const std::unordered_map & need_outside, bool force_rewrite) +void ExpressionActions::rewriteShortCircuitArguments(const ActionsDAG::NodeRawConstPtrs & children, const std::unordered_map & need_outside, bool force_rewrite) { - bool have_rewritten_child = false; for (const auto * child : children) { - if (!need_outside.contains(child) || need_outside.at(child)) + if (!need_outside.contains(child) || need_outside.at(child) || child->lazy_execution != ActionsDAG::LazyExecution::DISABLED) continue; - if (child->is_lazy_executed) - { - have_rewritten_child = true; - continue; - } - switch (child->type) { case ActionsDAG::ActionType::FUNCTION: - if (rewriteShortCircuitArguments(child->children, need_outside, force_rewrite) || child->function_base->isSuitableForShortCircuitArgumentsExecution() || force_rewrite) - { - const_cast(child)->is_lazy_executed = true; - have_rewritten_child = true; - } + rewriteShortCircuitArguments(child->children, need_outside, force_rewrite); + const_cast(child)->lazy_execution = force_rewrite ? ActionsDAG::LazyExecution::FORCE_ENABLED : ActionsDAG::LazyExecution::ENABLED; break; case ActionsDAG::ActionType::ALIAS: - have_rewritten_child |= rewriteShortCircuitArguments(child->children, need_outside, force_rewrite); + rewriteShortCircuitArguments(child->children, need_outside, force_rewrite); break; default: break; } } - - return have_rewritten_child; } @@ -426,7 +415,9 @@ static void executeAction(const ExpressionActions::Action & action, ExecutionCon arguments[i] = columns[action.arguments[i].pos]; } - if (action.node->is_lazy_executed) + if (action.node->lazy_execution == ActionsDAG::LazyExecution::FORCE_ENABLED + || (action.node->lazy_execution == ActionsDAG::LazyExecution::ENABLED + && (action.node->function_base->isSuitableForShortCircuitArgumentsExecution() || checkShirtCircuitArguments(arguments) >= 0))) res_column.column = ColumnFunction::create(num_rows, action.node->function_base, std::move(arguments), true); else { diff --git a/src/Interpreters/ExpressionActions.h b/src/Interpreters/ExpressionActions.h index 419006c572b..08fc7e73122 100644 --- a/src/Interpreters/ExpressionActions.h +++ b/src/Interpreters/ExpressionActions.h @@ -133,7 +133,7 @@ private: void checkLimits(const ColumnsWithTypeAndName & columns) const; void linearizeActions(); - bool rewriteShortCircuitArguments( + void rewriteShortCircuitArguments( const ActionsDAG::NodeRawConstPtrs & children, const std::unordered_map & need_outside, bool force_rewrite); void rewriteArgumentsForShortCircuitFunctions( diff --git a/src/Interpreters/ExpressionJIT.cpp b/src/Interpreters/ExpressionJIT.cpp index 497aa56ab13..f98667a7fbe 100644 --- a/src/Interpreters/ExpressionJIT.cpp +++ b/src/Interpreters/ExpressionJIT.cpp @@ -166,6 +166,8 @@ public: return dag.compile(builder, values); } + bool isSuitableForShortCircuitArgumentsExecution() const override { return true; } + String getName() const override { return name; } const DataTypes & getArgumentTypes() const override { return argument_types; }