Remove weird special case in if() return type inference, to match multiIf()

This commit is contained in:
Michael Kolupaev 2024-03-30 00:35:34 +00:00
parent 059f1abcf6
commit c1ea1726b4
2 changed files with 10 additions and 10 deletions

View File

@ -147,7 +147,7 @@ private:
continue; continue;
throw Exception(ErrorCodes::LOGICAL_ERROR, throw Exception(ErrorCodes::LOGICAL_ERROR,
"Function {} expects {} argument to have {} type but receives {} after running {} pass", "Function {} expects argument {} to have {} type but receives {} after running {} pass",
function->toAST()->formatForErrorMessage(), function->toAST()->formatForErrorMessage(),
i + 1, i + 1,
expected_argument_type->getName(), expected_argument_type->getName(),

View File

@ -1278,9 +1278,8 @@ public:
/// Get result types by argument types. If the function does not apply to these arguments, throw an exception. /// Get result types by argument types. If the function does not apply to these arguments, throw an exception.
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{ {
if (arguments[0]->onlyNull()) if (!arguments[0]->onlyNull())
return arguments[2]; {
if (arguments[0]->isNullable()) if (arguments[0]->isNullable())
return getReturnTypeImpl({ return getReturnTypeImpl({
removeNullable(arguments[0]), arguments[1], arguments[2]}); removeNullable(arguments[0]), arguments[1], arguments[2]});
@ -1288,6 +1287,7 @@ public:
if (!WhichDataType(arguments[0]).isUInt8()) if (!WhichDataType(arguments[0]).isUInt8())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of first argument (condition) of function if. " throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of first argument (condition) of function if. "
"Must be UInt8.", arguments[0]->getName()); "Must be UInt8.", arguments[0]->getName());
}
if (use_variant_when_no_common_type) if (use_variant_when_no_common_type)
return getLeastSupertypeOrVariant(DataTypes{arguments[1], arguments[2]}); return getLeastSupertypeOrVariant(DataTypes{arguments[1], arguments[2]});