diff --git a/dbms/src/Functions/throwIf.cpp b/dbms/src/Functions/throwIf.cpp index 15584aa26a7..dc4ac4950e8 100644 --- a/dbms/src/Functions/throwIf.cpp +++ b/dbms/src/Functions/throwIf.cpp @@ -1,9 +1,11 @@ #include #include #include +#include #include #include #include +#include namespace DB @@ -13,6 +15,7 @@ namespace ErrorCodes { extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int FUNCTION_THROW_IF_VALUE_IS_NON_ZERO; } @@ -32,46 +35,70 @@ public: return name; } + bool isVariadic() const override { return true; } size_t getNumberOfArguments() const override { - return 1; + return 0; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { - if (!isNativeNumber(arguments.front())) + const size_t number_of_arguments = arguments.size(); + + if (number_of_arguments < 1 || number_of_arguments > 2) + throw Exception{"Number of arguments for function " + getName() + " doesn't match: passed " + + toString(number_of_arguments) + ", should be 1 or 2", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; + + if (!isNativeNumber(arguments[0])) throw Exception{"Argument for function " + getName() + " must be number", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + if (number_of_arguments > 1 && !isString(arguments[1])) + throw Exception{"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + + return std::make_shared(); } bool useDefaultImplementationForConstants() const override { return true; } + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override { + std::optional custom_message; + if (arguments.size() == 2) + { + auto * msg_column = checkAndGetColumnConst(block.getByPosition(arguments[1]).column.get()); + if (!msg_column) + throw Exception{"Second argument for function " + getName() + " must be constant String", ErrorCodes::ILLEGAL_COLUMN}; + custom_message = msg_column->getValue(); + } + const auto in = block.getByPosition(arguments.front()).column.get(); - if ( !execute(block, in, result) - && !execute(block, in, result) - && !execute(block, in, result) - && !execute(block, in, result) - && !execute(block, in, result) - && !execute(block, in, result) - && !execute(block, in, result) - && !execute(block, in, result) - && !execute(block, in, result) - && !execute(block, in, result)) + if ( !execute(block, in, result, custom_message) + && !execute(block, in, result, custom_message) + && !execute(block, in, result, custom_message) + && !execute(block, in, result, custom_message) + && !execute(block, in, result, custom_message) + && !execute(block, in, result, custom_message) + && !execute(block, in, result, custom_message) + && !execute(block, in, result, custom_message) + && !execute(block, in, result, custom_message) + && !execute(block, in, result, custom_message)) throw Exception{"Illegal column " + in->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN}; } template - bool execute(Block & block, const IColumn * in_untyped, const size_t result) + bool execute(Block & block, const IColumn * in_untyped, const size_t result, const std::optional & message) { if (const auto in = checkAndGetColumn>(in_untyped)) { const auto & in_data = in->getData(); if (!memoryIsZero(in_data.data(), in_data.size() * sizeof(in_data[0]))) - throw Exception("Value passed to 'throwIf' function is non zero", ErrorCodes::FUNCTION_THROW_IF_VALUE_IS_NON_ZERO); + throw Exception{message.value_or("Value passed to '" + getName() + "' function is non zero"), + ErrorCodes::FUNCTION_THROW_IF_VALUE_IS_NON_ZERO}; /// We return non constant to avoid constant folding. block.getByPosition(result).column = ColumnUInt8::create(in_data.size(), 0); diff --git a/dbms/tests/queries/0_stateless/00602_throw_if.reference b/dbms/tests/queries/0_stateless/00602_throw_if.reference index d0752a77fc7..ad5aaee89a8 100644 --- a/dbms/tests/queries/0_stateless/00602_throw_if.reference +++ b/dbms/tests/queries/0_stateless/00602_throw_if.reference @@ -1,2 +1,3 @@ 1 +1 1000000 diff --git a/dbms/tests/queries/0_stateless/00602_throw_if.sh b/dbms/tests/queries/0_stateless/00602_throw_if.sh index 8dae5033978..69039891bd2 100755 --- a/dbms/tests/queries/0_stateless/00602_throw_if.sh +++ b/dbms/tests/queries/0_stateless/00602_throw_if.sh @@ -3,7 +3,9 @@ CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) . $CURDIR/../shell_config.sh -exception_pattern="Value passed to 'throwIf' function is non zero" +default_exception_message="Value passed to 'throwIf' function is non zero" +custom_exception_message="Number equals 1000000" -${CLICKHOUSE_CLIENT} --server_logs_file /dev/null --query="SELECT throwIf(number = 1000000) FROM system.numbers" 2>&1 | grep -cF "$exception_pattern" +${CLICKHOUSE_CLIENT} --server_logs_file /dev/null --query="SELECT throwIf(number = 1000000) FROM system.numbers" 2>&1 | grep -cF "$default_exception_message" +${CLICKHOUSE_CLIENT} --server_logs_file /dev/null --query="SELECT throwIf(number = 1000000, '$custom_exception_message') FROM system.numbers" 2>&1 | grep -cF "$custom_exception_message" ${CLICKHOUSE_CLIENT} --server_logs_file /dev/null --query="SELECT sum(x = 0) FROM (SELECT throwIf(number = 1000000) AS x FROM numbers(1000000))" 2>&1