diff --git a/src/Columns/ColumnFunction.h b/src/Columns/ColumnFunction.h index 8e39551676c..2592dc01f98 100644 --- a/src/Columns/ColumnFunction.h +++ b/src/Columns/ColumnFunction.h @@ -24,7 +24,12 @@ class ColumnFunction final : public COWHelper private: friend class COWHelper; - ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool is_short_circuit_argument_ = false, bool is_function_compiled_ = false); + ColumnFunction( + size_t size, + FunctionBasePtr function_, + const ColumnsWithTypeAndName & columns_to_capture, + bool is_short_circuit_argument_ = false, + bool is_function_compiled_ = false); public: const char * getFamilyName() const override { return "Function"; } diff --git a/src/Columns/MaskOperations.cpp b/src/Columns/MaskOperations.cpp index 9499185da30..1641bdf5a4c 100644 --- a/src/Columns/MaskOperations.cpp +++ b/src/Columns/MaskOperations.cpp @@ -293,7 +293,7 @@ void executeColumnIfNeeded(ColumnWithTypeAndName & column, bool empty) column.column = column_function->getResultType()->createColumn(); } -int checkShirtCircuitArguments(const ColumnsWithTypeAndName & arguments) +int checkShortCircuitArguments(const ColumnsWithTypeAndName & arguments) { int last_short_circuit_argument_index = -1; for (size_t i = 0; i != arguments.size(); ++i) diff --git a/src/Columns/MaskOperations.h b/src/Columns/MaskOperations.h index bd6c5e8fe2c..e43b4588258 100644 --- a/src/Columns/MaskOperations.h +++ b/src/Columns/MaskOperations.h @@ -66,7 +66,7 @@ void executeColumnIfNeeded(ColumnWithTypeAndName & column, bool empty = false); /// Check if arguments contain lazy executed argument. If contain, return index of the last one, /// otherwise return -1. -int checkShirtCircuitArguments(const ColumnsWithTypeAndName & arguments); +int checkShortCircuitArguments(const ColumnsWithTypeAndName & arguments); void copyMask(const PaddedPODArray & from, PaddedPODArray & to); diff --git a/src/Functions/FunctionsLogical.cpp b/src/Functions/FunctionsLogical.cpp index f427deced3a..87a2ecd4c57 100644 --- a/src/Functions/FunctionsLogical.cpp +++ b/src/Functions/FunctionsLogical.cpp @@ -609,7 +609,7 @@ ColumnPtr FunctionAnyArityLogical::executeImpl( ColumnsWithTypeAndName arguments = std::move(args); /// Special implementation for short-circuit arguments. - if (checkShirtCircuitArguments(arguments) != -1) + if (checkShortCircuitArguments(arguments) != -1) return executeShortCircuit(arguments, result_type); ColumnRawPtrs args_in; diff --git a/src/Functions/if.cpp b/src/Functions/if.cpp index 953aff3568e..6841098ebcf 100644 --- a/src/Functions/if.cpp +++ b/src/Functions/if.cpp @@ -969,7 +969,7 @@ private: static void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) { - int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments); + int last_short_circuit_argument_index = checkShortCircuitArguments(arguments); if (last_short_circuit_argument_index == -1) return; diff --git a/src/Functions/multiIf.cpp b/src/Functions/multiIf.cpp index 3e5242d5f9b..070a7c2f05e 100644 --- a/src/Functions/multiIf.cpp +++ b/src/Functions/multiIf.cpp @@ -262,7 +262,7 @@ public: private: static void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments) { - int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments); + int last_short_circuit_argument_index = checkShortCircuitArguments(arguments); if (last_short_circuit_argument_index < 0) return; diff --git a/src/Functions/throwIf.cpp b/src/Functions/throwIf.cpp index d499f1f492f..7533e30c9b9 100644 --- a/src/Functions/throwIf.cpp +++ b/src/Functions/throwIf.cpp @@ -48,36 +48,53 @@ public: const size_t number_of_arguments = arguments.size(); if (number_of_arguments < 1 || number_of_arguments > 2) - throw Exception{"Number of arguments for function " + getName() + " doesn't match: passed " - + toString(number_of_arguments) + ", should be 1 or 2", - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Number of arguments for function {} doesn't match: passed {}, should be 1 or 2", + getName(), + toString(number_of_arguments)); if (!isNativeNumber(arguments[0])) - throw Exception{"Argument for function " + getName() + " must be number", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Argument for function {} must be number", + getName()); if (number_of_arguments > 1 && !isString(arguments[1])) - throw Exception{"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of argument of function {}", + arguments[1]->getName(), + getName()); return std::make_shared(); } - bool useDefaultImplementationForConstants() const override { return true; } + bool useDefaultImplementationForConstants() const override { return false; } ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override + /** Prevent constant folding for FunctionThrowIf because for short circuit evaluation + * it is unsafe to evaluate this function during DAG analysis. + */ + bool isSuitableForConstantFolding() const override { return false; } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override { + if (input_rows_count == 0) + return result_type->createColumn(); + std::optional custom_message; if (arguments.size() == 2) { - const auto * msg_column = checkAndGetColumnConst(arguments[1].column.get()); - if (!msg_column) - throw Exception{"Second argument for function " + getName() + " must be constant String", ErrorCodes::ILLEGAL_COLUMN}; - custom_message = msg_column->getValue(); + const auto * message_column = checkAndGetColumnConst(arguments[1].column.get()); + if (!message_column) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, + "Second argument for function {} must be constant String", + getName()); + + custom_message = message_column->getValue(); } - const auto * in = arguments.front().column.get(); + auto first_argument_column = arguments.front().column; + const auto * in = first_argument_column.get(); ColumnPtr res; if (!((res = execute(in, custom_message)) @@ -90,7 +107,9 @@ public: || (res = execute(in, custom_message)) || (res = execute(in, custom_message)) || (res = execute(in, custom_message)))) + { throw Exception{"Illegal column " + in->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN}; + } return res; } @@ -98,15 +117,22 @@ public: template ColumnPtr execute(const IColumn * in_untyped, const std::optional & message) const { - if (const auto in = checkAndGetColumn>(in_untyped)) + const auto * in = checkAndGetColumn>(in_untyped); + + if (!in) + in = checkAndGetColumnConstData>(in_untyped); + + if (in) { const auto & in_data = in->getData(); if (!memoryIsZero(in_data.data(), in_data.size() * sizeof(in_data[0]))) - throw Exception{message.value_or("Value passed to '" + getName() + "' function is non zero"), - ErrorCodes::FUNCTION_THROW_IF_VALUE_IS_NON_ZERO}; + { + throw Exception(ErrorCodes::FUNCTION_THROW_IF_VALUE_IS_NON_ZERO, + message.value_or("Value passed to '" + getName() + "' function is non zero")); + } /// We return non constant to avoid constant folding. - return ColumnUInt8::create(in_data.size(), 0); + return ColumnUInt8::create(in_data.size(), 0); } return nullptr; diff --git a/tests/queries/0_stateless/02152_short_circuit_throw_if.reference b/tests/queries/0_stateless/02152_short_circuit_throw_if.reference new file mode 100644 index 00000000000..aa47d0d46d4 --- /dev/null +++ b/tests/queries/0_stateless/02152_short_circuit_throw_if.reference @@ -0,0 +1,2 @@ +0 +0 diff --git a/tests/queries/0_stateless/02152_short_circuit_throw_if.sql b/tests/queries/0_stateless/02152_short_circuit_throw_if.sql new file mode 100644 index 00000000000..3fdc3cc48c8 --- /dev/null +++ b/tests/queries/0_stateless/02152_short_circuit_throw_if.sql @@ -0,0 +1,2 @@ +SELECT if(1, 0, throwIf(1, 'Executing FALSE branch')); +SELECT if(empty(''), 0, throwIf(1, 'Executing FALSE branch'));