Add fixes, add new mode to getLeastSupertype and use it in if/multiIf

This commit is contained in:
avogar 2024-01-11 18:44:05 +00:00
parent f05d89bc2b
commit 9e639df12e
6 changed files with 79 additions and 32 deletions

View File

@ -77,7 +77,7 @@ static ReturnType addElementSafe(size_t num_elems, IColumn & column, F && impl)
auto & element_column = extractElementColumn(column, i);
if (element_column.size() > old_size)
{
chassert(old_size - element_column.size() == 1);
chassert(element_column.size() - old_size == 1);
element_column.popBack(1);
}
}

View File

@ -18,6 +18,7 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeVariant.h>
namespace DB
@ -58,6 +59,25 @@ DataTypePtr throwOrReturn(const DataTypes & types, std::string_view message_suff
if constexpr (on_error == LeastSupertypeOnError::String)
return std::make_shared<DataTypeString>();
if constexpr (on_error == LeastSupertypeOnError::Variant && std::is_same_v<DataTypes, std::vector<DataTypePtr>>)
{
DataTypes variants;
for (const auto & type : types)
{
if (isVariant(type))
{
const DataTypes & nested_variants = assert_cast<const DataTypeVariant &>(*type).getVariants();
variants.insert(variants.end(), nested_variants.begin(), nested_variants.end());
}
else
{
variants.push_back(removeNullableOrLowCardinalityNullable(type));
}
}
return std::make_shared<DataTypeVariant>(variants);
}
if constexpr (on_error == LeastSupertypeOnError::Null)
return nullptr;
@ -67,8 +87,8 @@ DataTypePtr throwOrReturn(const DataTypes & types, std::string_view message_suff
throw Exception(error_code, "There is no supertype for types {} {}", getExceptionMessagePrefix(types), message_suffix);
}
template <LeastSupertypeOnError on_error>
DataTypePtr getNumericType(const TypeIndexSet & types)
template <typename ThrowOrReturnFunc>
DataTypePtr getNumericType(const TypeIndexSet & types, ThrowOrReturnFunc throwOrReturnFunc)
{
bool all_numbers = true;
@ -119,7 +139,7 @@ DataTypePtr getNumericType(const TypeIndexSet & types)
if (max_bits_of_signed_integer || max_bits_of_unsigned_integer || max_mantissa_bits_of_floating)
{
if (!all_numbers)
return throwOrReturn<on_error>(types, "because some of them are numbers and some of them are not", ErrorCodes::NO_COMMON_TYPE);
return throwOrReturnFunc(types, "because some of them are numbers and some of them are not", ErrorCodes::NO_COMMON_TYPE);
/// If there are signed and unsigned types of same bit-width, the result must be signed number with at least one more bit.
/// Example, common of Int32, UInt32 = Int64.
@ -134,7 +154,7 @@ DataTypePtr getNumericType(const TypeIndexSet & types)
if (min_bit_width_of_integer != 64)
++min_bit_width_of_integer;
else
return throwOrReturn<on_error>(types,
return throwOrReturnFunc(types,
"because some of them are signed integers and some are unsigned integers,"
" but there is no signed integer type, that can exactly represent all required unsigned integer values",
ErrorCodes::NO_COMMON_TYPE);
@ -149,7 +169,7 @@ DataTypePtr getNumericType(const TypeIndexSet & types)
else if (min_mantissa_bits <= 53)
return std::make_shared<DataTypeFloat64>();
else
return throwOrReturn<on_error>(types,
return throwOrReturnFunc(types,
" because some of them are integers and some are floating point,"
" but there is no floating point type, that can exactly represent all required integers", ErrorCodes::NO_COMMON_TYPE);
}
@ -170,7 +190,7 @@ DataTypePtr getNumericType(const TypeIndexSet & types)
else if (min_bit_width_of_integer <= 256)
return std::make_shared<DataTypeInt256>();
else
return throwOrReturn<on_error>(types,
return throwOrReturnFunc(types,
" because some of them are signed integers and some are unsigned integers,"
" but there is no signed integer type, that can exactly represent all required unsigned integer values", ErrorCodes::NO_COMMON_TYPE);
}
@ -190,7 +210,7 @@ DataTypePtr getNumericType(const TypeIndexSet & types)
else if (min_bit_width_of_integer <= 256)
return std::make_shared<DataTypeUInt256>();
else
return throwOrReturn<on_error>(types,
return throwOrReturnFunc(types,
" but as all data types are unsigned integers, we must have found maximum unsigned integer type", ErrorCodes::NO_COMMON_TYPE);
}
}
@ -382,7 +402,18 @@ DataTypePtr getLeastSupertype(const DataTypes & types)
if (!all_maps)
return throwOrReturn<on_error>(types, "because some of them are Maps and some of them are not", ErrorCodes::NO_COMMON_TYPE);
auto keys_common_type = getLeastSupertype<on_error>(key_types);
DataTypePtr keys_common_type;
if constexpr (on_error == LeastSupertypeOnError::Variant)
{
keys_common_type = getLeastSupertype<LeastSupertypeOnError::Null>(key_types);
if (!keys_common_type)
return throwOrReturn<on_error>(types, "", ErrorCodes::NO_COMMON_TYPE);
}
else
{
keys_common_type = getLeastSupertype<on_error>(key_types);
}
auto values_common_type = getLeastSupertype<on_error>(value_types);
/// When on_error == LeastSupertypeOnError::Null and we cannot get least supertype for keys or values,
/// keys_common_type or values_common_type will be nullptr, we should return nullptr in this case.
@ -423,7 +454,18 @@ DataTypePtr getLeastSupertype(const DataTypes & types)
return getLeastSupertype<on_error>(nested_types);
else
{
auto nested_type = getLeastSupertype<on_error>(nested_types);
DataTypePtr nested_type;
if constexpr (on_error == LeastSupertypeOnError::Variant)
{
nested_type = getLeastSupertype<LeastSupertypeOnError::Null>(nested_types);
if (!nested_type)
return throwOrReturn<on_error>(types, "", ErrorCodes::NO_COMMON_TYPE);
}
else
{
nested_type = getLeastSupertype<on_error>(nested_types);
}
/// When on_error == LeastSupertypeOnError::Null and we cannot get least supertype,
/// nested_type will be nullptr, we should return nullptr in this case.
if (!nested_type)
@ -456,6 +498,8 @@ DataTypePtr getLeastSupertype(const DataTypes & types)
if (have_nullable)
{
auto nested_type = getLeastSupertype<on_error>(nested_types);
if (isVariant(nested_type))
return nested_type;
/// When on_error == LeastSupertypeOnError::Null and we cannot get least supertype,
/// nested_type will be nullptr, we should return nullptr in this case.
if (!nested_type)
@ -623,7 +667,8 @@ DataTypePtr getLeastSupertype(const DataTypes & types)
{
/// First, if we have signed integers, try to convert all UInt64 to Int64 if possible.
convertUInt64toInt64IfPossible(types, type_ids);
auto numeric_type = getNumericType<on_error>(type_ids);
auto throw_or_return = [&](const TypeIndexSet &, std::string_view message_suffix, int error_code){ return throwOrReturn<on_error>(types, message_suffix, error_code); };
auto numeric_type = getNumericType(type_ids, throw_or_return);
if (numeric_type)
return numeric_type;
}
@ -637,6 +682,11 @@ DataTypePtr getLeastSupertypeOrString(const DataTypes & types)
return getLeastSupertype<LeastSupertypeOnError::String>(types);
}
DataTypePtr getLeastSupertypeOrVariant(const DataTypes & types)
{
return getLeastSupertype<LeastSupertypeOnError::Variant>(types);
}
DataTypePtr tryGetLeastSupertype(const DataTypes & types)
{
return getLeastSupertype<LeastSupertypeOnError::Null>(types);
@ -676,7 +726,8 @@ DataTypePtr getLeastSupertype(const TypeIndexSet & types)
return std::make_shared<DataTypeString>();
}
auto numeric_type = getNumericType<on_error>(types);
auto throw_or_return = [](const TypeIndexSet & type_ids, std::string_view message_suffix, int error_code){ return throwOrReturn<on_error>(type_ids, message_suffix, error_code); };
auto numeric_type = getNumericType(types, throw_or_return);
if (numeric_type)
return numeric_type;

View File

@ -8,6 +8,7 @@ enum class LeastSupertypeOnError
{
Throw,
String,
Variant,
Null,
};
@ -24,6 +25,17 @@ DataTypePtr getLeastSupertype(const DataTypes & types);
/// All types can be casted to String, because they can be serialized to String.
DataTypePtr getLeastSupertypeOrString(const DataTypes & types);
/// Same as getLeastSupertype but in case when there is no supertype for some types
/// it uses Variant of these types as a supertype. Any type can be casted to a Variant
/// that contains this type.
/// As nested Variants are not allowed, if one of the types is Variant, it's variants
/// are used in the resulting Variant.
/// Examples:
/// (UInt64, String) -> Variant(UInt64, String)
/// (Array(UInt64), Array(String)) -> Array(Variant(UInt64, String))
/// (Variant(UInt64, String), Array(UInt32)) -> Variant(UInt64, String, Array(UInt32))
DataTypePtr getLeastSupertypeOrVariant(const DataTypes & types);
/// Same as above but return nullptr instead of throwing exception.
DataTypePtr tryGetLeastSupertype(const DataTypes & types);

View File

@ -688,15 +688,9 @@ private:
DataTypePtr common_type;
if (use_variant_when_no_common_type)
{
common_type = tryGetLeastSupertype(DataTypes{arg1.type, arg2.type});
if (!common_type)
common_type = std::make_shared<DataTypeVariant>(DataTypes{removeNullableOrLowCardinalityNullable(arg1.type), removeNullableOrLowCardinalityNullable(arg2.type)});
}
common_type = getLeastSupertypeOrVariant(DataTypes{arg1.type, arg2.type});
else
{
common_type = getLeastSupertype(DataTypes{arg1.type, arg2.type});
}
ColumnPtr col_then = castColumn(arg1, common_type);
ColumnPtr col_else = castColumn(arg2, common_type);
@ -1118,11 +1112,7 @@ public:
"Must be UInt8.", arguments[0]->getName());
if (use_variant_when_no_common_type)
{
if (auto res = tryGetLeastSupertype(DataTypes{arguments[1], arguments[2]}))
return res;
return std::make_shared<DataTypeVariant>(DataTypes{removeNullableOrLowCardinalityNullable(arguments[1]), removeNullableOrLowCardinalityNullable(arguments[2])});
}
return getLeastSupertypeOrVariant(DataTypes{arguments[1], arguments[2]});
return getLeastSupertype(DataTypes{arguments[1], arguments[2]});
}

View File

@ -119,13 +119,7 @@ public:
});
if (context->getSettingsRef().allow_experimental_variant_type && context->getSettingsRef().use_variant_when_no_common_type_in_if)
{
if (auto res = tryGetLeastSupertype(types_of_branches))
return res;
for (auto & type : types_of_branches)
type = removeNullableOrLowCardinalityNullable(type);
return std::make_shared<DataTypeVariant>(types_of_branches);
}
return getLeastSupertypeOrVariant(types_of_branches);
return getLeastSupertype(types_of_branches);
}

View File

@ -505,7 +505,7 @@ String
(NULL,NULL),('string',NULL),(-1,-1),(0,0),(10000000000,NULL)(NULL,NULL),('string',NULL),(-1,NULL),(0,0),(10000000000,NULL)(NULL,NULL),('string',NULL),(-1,-1),(0,0),(10000000000,NULL)(NULL,NULL),('string',NULL),(-1,NULL),(0,0),(10000000000,NULL)(NULL,NULL),('string',NULL),(-1,-1),(0,0),(10000000000,NULL)(NULL,NULL),('string',NULL),(-1,NULL),(0,0),(10000000000,NULL)(NULL,NULL),('string',NULL),(-1,-1),(0,0),(10000000000000000000000,NULL)(NULL,NULL),('string',NULL),(-1,NULL),(0,0),(10000000000000000000000,NULL)(NULL,NULL),('string',NULL),(-1,-1),(0,0)(NULL,NULL),('string',NULL),(-1,NULL),(0,0)Floats
(NULL,NULL),('string',NULL),(42.42,42.42)(NULL,NULL),('string',NULL),(42.42,42.42)Decimals
(NULL,NULL),('string',NULL),(42.42,42.42)(NULL,NULL),('string',NULL),(42.42,42.42)(NULL,NULL),('string',NULL),(42.42,42.42)(NULL,NULL),('string',NULL),(42.42,42.42)Dates and DateTimes
(NULL,NULL),('string',NULL),('1970-01-01 00:00:00.000',NULL),('2020-01-01','2020-01-01'),('2020-01-01 00:00:00.999',NULL)(NULL,NULL),('string',NULL),('1970-01-01 00:00:00.000',NULL),('1900-01-01','1900-01-01'),('2020-01-01 00:00:00.999',NULL)(NULL,NULL),('string',NULL),('1970-01-01 00:00:00.000',NULL),('2020-01-01 00:00:00','2020-01-01 00:00:00'),('2020-01-01 00:00:00.999',NULL)(NULL,NULL),('string',NULL),('1970-01-01 00:00:00.000','1970-01-01 00:00:00.000'),('2020-01-01 00:00:00.999',NULL),('2020-01-01 00:00:00.999999999 ABC',NULL)UUID
(NULL,NULL),('string',NULL),('2020-01-d1',NULL),('2020-01-01','2020-01-01'),('2020-01-01 00:00:00.999',NULL)(NULL,NULL),('string',NULL),('2020-01-d1',NULL),('1900-01-01','1900-01-01'),('2020-01-01 00:00:00.999',NULL)(NULL,NULL),('string',NULL),('2020-01-d1',NULL),('2020-01-01 00:00:00','2020-01-01 00:00:00'),('2020-01-01 00:00:00.999',NULL)(NULL,NULL),('string',NULL),('2020-01-d1',NULL),('2020-01-01 00:00:00.999','2020-01-01 00:00:00.999'),('2020-01-01 00:00:00.999999999 ABC',NULL)UUID
(NULL,NULL),('string',NULL),('c8619cca-0caa-445e-ae76-1d4f6e0b3927','c8619cca-0caa-445e-ae76-1d4f6e0b3927'),('c8619cca-0caa-445e-ae76-1d4f6e0b3927AAA',NULL)IPv4
(NULL,NULL),('string',NULL),('127.0.0.1','127.0.0.1'),('127.0.0.1AAA',NULL)IPv6
(NULL,NULL),('string',NULL),('2001:db8:85a3::8a2e:370:7334','2001:db8:85a3::8a2e:370:7334'),('2001:0db8:85a3:0000:0000:8a2e:0370:7334AAA',NULL)Enum