Merge pull request #32973 from kitaisreal/short-circuit-throw-if-support

Short circuit evaluation function throwIf support
This commit is contained in:
Maksim Kita 2021-12-20 19:50:53 +03:00 committed by GitHub
commit daa23a2827
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 58 additions and 23 deletions

View File

@ -24,7 +24,12 @@ class ColumnFunction final : public COWHelper<IColumn, ColumnFunction>
private:
friend class COWHelper<IColumn, ColumnFunction>;
ColumnFunction(size_t size, FunctionBasePtr function_, const ColumnsWithTypeAndName & columns_to_capture, bool is_short_circuit_argument_ = false, bool is_function_compiled_ = false);
ColumnFunction(
size_t size,
FunctionBasePtr function_,
const ColumnsWithTypeAndName & columns_to_capture,
bool is_short_circuit_argument_ = false,
bool is_function_compiled_ = false);
public:
const char * getFamilyName() const override { return "Function"; }

View File

@ -293,7 +293,7 @@ void executeColumnIfNeeded(ColumnWithTypeAndName & column, bool empty)
column.column = column_function->getResultType()->createColumn();
}
int checkShirtCircuitArguments(const ColumnsWithTypeAndName & arguments)
int checkShortCircuitArguments(const ColumnsWithTypeAndName & arguments)
{
int last_short_circuit_argument_index = -1;
for (size_t i = 0; i != arguments.size(); ++i)

View File

@ -66,7 +66,7 @@ void executeColumnIfNeeded(ColumnWithTypeAndName & column, bool empty = false);
/// Check if arguments contain lazy executed argument. If contain, return index of the last one,
/// otherwise return -1.
int checkShirtCircuitArguments(const ColumnsWithTypeAndName & arguments);
int checkShortCircuitArguments(const ColumnsWithTypeAndName & arguments);
void copyMask(const PaddedPODArray<UInt8> & from, PaddedPODArray<UInt8> & to);

View File

@ -609,7 +609,7 @@ ColumnPtr FunctionAnyArityLogical<Impl, Name>::executeImpl(
ColumnsWithTypeAndName arguments = std::move(args);
/// Special implementation for short-circuit arguments.
if (checkShirtCircuitArguments(arguments) != -1)
if (checkShortCircuitArguments(arguments) != -1)
return executeShortCircuit(arguments, result_type);
ColumnRawPtrs args_in;

View File

@ -969,7 +969,7 @@ private:
static void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments)
{
int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments);
int last_short_circuit_argument_index = checkShortCircuitArguments(arguments);
if (last_short_circuit_argument_index == -1)
return;

View File

@ -262,7 +262,7 @@ public:
private:
static void executeShortCircuitArguments(ColumnsWithTypeAndName & arguments)
{
int last_short_circuit_argument_index = checkShirtCircuitArguments(arguments);
int last_short_circuit_argument_index = checkShortCircuitArguments(arguments);
if (last_short_circuit_argument_index < 0)
return;

View File

@ -48,36 +48,53 @@ public:
const size_t number_of_arguments = arguments.size();
if (number_of_arguments < 1 || number_of_arguments > 2)
throw Exception{"Number of arguments for function " + getName() + " doesn't match: passed "
+ toString(number_of_arguments) + ", should be 1 or 2",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 1 or 2",
getName(),
toString(number_of_arguments));
if (!isNativeNumber(arguments[0]))
throw Exception{"Argument for function " + getName() + " must be number", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Argument for function {} must be number",
getName());
if (number_of_arguments > 1 && !isString(arguments[1]))
throw Exception{"Illegal type " + arguments[1]->getName() + " of argument of function " + getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument of function {}",
arguments[1]->getName(),
getName());
return std::make_shared<DataTypeUInt8>();
}
bool useDefaultImplementationForConstants() const override { return true; }
bool useDefaultImplementationForConstants() const override { return false; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
/** Prevent constant folding for FunctionThrowIf because for short circuit evaluation
* it is unsafe to evaluate this function during DAG analysis.
*/
bool isSuitableForConstantFolding() const override { return false; }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
{
if (input_rows_count == 0)
return result_type->createColumn();
std::optional<String> custom_message;
if (arguments.size() == 2)
{
const auto * msg_column = checkAndGetColumnConst<ColumnString>(arguments[1].column.get());
if (!msg_column)
throw Exception{"Second argument for function " + getName() + " must be constant String", ErrorCodes::ILLEGAL_COLUMN};
custom_message = msg_column->getValue<String>();
const auto * message_column = checkAndGetColumnConst<ColumnString>(arguments[1].column.get());
if (!message_column)
throw Exception(ErrorCodes::ILLEGAL_COLUMN,
"Second argument for function {} must be constant String",
getName());
custom_message = message_column->getValue<String>();
}
const auto * in = arguments.front().column.get();
auto first_argument_column = arguments.front().column;
const auto * in = first_argument_column.get();
ColumnPtr res;
if (!((res = execute<UInt8>(in, custom_message))
@ -90,7 +107,9 @@ public:
|| (res = execute<Int64>(in, custom_message))
|| (res = execute<Float32>(in, custom_message))
|| (res = execute<Float64>(in, custom_message))))
{
throw Exception{"Illegal column " + in->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN};
}
return res;
}
@ -98,15 +117,22 @@ public:
template <typename T>
ColumnPtr execute(const IColumn * in_untyped, const std::optional<String> & message) const
{
if (const auto in = checkAndGetColumn<ColumnVector<T>>(in_untyped))
const auto * in = checkAndGetColumn<ColumnVector<T>>(in_untyped);
if (!in)
in = checkAndGetColumnConstData<ColumnVector<T>>(in_untyped);
if (in)
{
const auto & in_data = in->getData();
if (!memoryIsZero(in_data.data(), in_data.size() * sizeof(in_data[0])))
throw Exception{message.value_or("Value passed to '" + getName() + "' function is non zero"),
ErrorCodes::FUNCTION_THROW_IF_VALUE_IS_NON_ZERO};
{
throw Exception(ErrorCodes::FUNCTION_THROW_IF_VALUE_IS_NON_ZERO,
message.value_or("Value passed to '" + getName() + "' function is non zero"));
}
/// We return non constant to avoid constant folding.
return ColumnUInt8::create(in_data.size(), 0);
return ColumnUInt8::create(in_data.size(), 0);
}
return nullptr;

View File

@ -0,0 +1,2 @@
0
0

View File

@ -0,0 +1,2 @@
SELECT if(1, 0, throwIf(1, 'Executing FALSE branch'));
SELECT if(empty(''), 0, throwIf(1, 'Executing FALSE branch'));