Try using input_rows_count as validation

This commit is contained in:
Raúl Marín 2024-05-21 18:37:59 +02:00
parent a53bd5793a
commit d4b723bcbe

View File

@ -299,32 +299,27 @@ bool isDecimalOrNullableDecimal(const DataTypePtr & type)
return isDecimal(assert_cast<const DataTypeNullable *>(type.get())->getNestedType());
}
void checkFunctionArgumentSizes(const ColumnsWithTypeAndName & arguments [[maybe_unused]], size_t input_rows_count [[maybe_unused]])
/// Note that, for historical reasons, most of the functions use the first argument size to determine which is the
/// size of all the columns. When short circuit optimization was introduced, `input_rows_count` was also added for
/// all functions, but many have not been adjusted
void checkFunctionArgumentSizes(const ColumnsWithTypeAndName & arguments, size_t input_rows_count)
{
if (!arguments.empty())
for (size_t i = 0; i < arguments.size(); i++)
{
/// Note that ideally this check should be simpler and we should check that all columns should either be const
/// or have exactly size input_rows_count
/// For historical reasons this is not the case, and many functions rely on the size of the first column
/// to decide which is the size of all the inputs
/// Hopefully this will be slowly improved in the future
if (isColumnConst(*arguments[i].column))
continue;
if (!isColumnConst(*arguments[0].column))
{
size_t expected_size = arguments[0].column->size();
size_t current_size = arguments[i].column->size();
/// TODO: Function name in the message?
for (size_t i = 1; i < arguments.size(); i++)
if (!isColumnConst(*arguments[i].column) && arguments[i].column->size() != expected_size)
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Expected the argument nº#{} ('{}' of type {}) to have {} rows, but it has {}",
i + 1,
arguments[i].name,
arguments[i].type->getName(),
expected_size,
arguments[i].column->size());
}
if (current_size != input_rows_count)
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"Expected the argument nº#{} ('{}' of type {}) to have {} rows, but it has {}",
i + 1,
arguments[i].name,
arguments[i].type->getName(),
input_rows_count,
current_size);
}
}
}