Merge pull request #54553 from Avogar/better-types-inference

Better integer types inference for Int64/UInt64 fields
This commit is contained in:
Alexey Milovidov 2023-09-24 02:07:48 +03:00 committed by GitHub
commit 776c6adfe6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 91 additions and 59 deletions

View File

@ -35,6 +35,12 @@ static DataTypePtr createNumericDataType(const ASTPtr & arguments)
return std::make_shared<DataTypeNumber<T>>();
}
bool isUInt64ThatCanBeInt64(const DataTypePtr & type)
{
const DataTypeUInt64 * uint64_type = typeid_cast<const DataTypeUInt64 *>(type.get());
return uint64_type && uint64_type->canUnsignedBeSigned();
}
void registerDataTypeNumbers(DataTypeFactory & factory)
{

View File

@ -9,10 +9,17 @@
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
template <typename T>
class DataTypeNumber final : public DataTypeNumberBase<T>
{
public:
DataTypeNumber() = default;
bool equals(const IDataType & rhs) const override { return typeid(rhs) == typeid(*this); }
bool canBeUsedAsVersion() const override { return true; }
@ -32,6 +39,20 @@ public:
{
return std::make_shared<SerializationNumber<T>>();
}
/// Special constructor for unsigned integers that can also fit into signed integer.
/// It's used for better type inference from fields.
/// See getLeastSupertype.cpp::convertUInt64toInt64IfPossible and FieldToDataType.cpp
///
/// The argument is stored as-is in `unsigned_can_be_signed`.
/// Throws Exception with ErrorCodes::LOGICAL_ERROR when T is signed according to std::is_signed_v<T>.
explicit DataTypeNumber(bool unsigned_can_be_signed_) : DataTypeNumberBase<T>(), unsigned_can_be_signed(unsigned_can_be_signed_)
{
/// The check is on T only (if constexpr), so for a signed T this constructor throws regardless of the argument.
if constexpr (std::is_signed_v<T>)
throw Exception(ErrorCodes::LOGICAL_ERROR, "DataTypeNumber constructor with bool argument should not be used with signed integers");
}
bool canUnsignedBeSigned() const { return unsigned_can_be_signed; }
private:
bool unsigned_can_be_signed = false;
};
using DataTypeUInt8 = DataTypeNumber<UInt8>;
@ -50,4 +71,6 @@ using DataTypeInt128 = DataTypeNumber<Int128>;
using DataTypeUInt256 = DataTypeNumber<UInt256>;
using DataTypeInt256 = DataTypeNumber<Int256>;
bool isUInt64ThatCanBeInt64(const DataTypePtr & type);
}

View File

@ -36,6 +36,7 @@ DataTypePtr FieldToDataType<on_error>::operator() (const UInt64 & x) const
if (x <= std::numeric_limits<UInt8>::max()) return std::make_shared<DataTypeUInt8>();
if (x <= std::numeric_limits<UInt16>::max()) return std::make_shared<DataTypeUInt16>();
if (x <= std::numeric_limits<UInt32>::max()) return std::make_shared<DataTypeUInt32>();
if (x <= std::numeric_limits<Int64>::max()) return std::make_shared<DataTypeUInt64>(/*unsigned_can_be_signed=*/true);
return std::make_shared<DataTypeUInt64>();
}
@ -136,17 +137,8 @@ DataTypePtr FieldToDataType<on_error>::operator() (const Array & x) const
DataTypes element_types;
element_types.reserve(x.size());
bool has_signed_int = false;
bool uint64_convert_possible = true;
for (const Field & elem : x)
{
DataTypePtr type = applyVisitor(*this, elem);
element_types.emplace_back(type);
checkUInt64ToIn64Conversion(has_signed_int, uint64_convert_possible, type, elem);
}
if (has_signed_int && uint64_convert_possible)
convertUInt64ToInt64IfPossible(element_types);
element_types.emplace_back(applyVisitor(*this, elem));
return std::make_shared<DataTypeArray>(getLeastSupertype<on_error>(element_types));
}
@ -174,28 +166,14 @@ DataTypePtr FieldToDataType<on_error>::operator() (const Map & map) const
key_types.reserve(map.size());
value_types.reserve(map.size());
bool k_has_signed_int = false;
bool k_uint64_convert_possible = true;
bool v_has_signed_int = false;
bool v_uint64_convert_possible = true;
for (const auto & elem : map)
{
const auto & tuple = elem.safeGet<const Tuple &>();
assert(tuple.size() == 2);
DataTypePtr k_type = applyVisitor(*this, tuple[0]);
key_types.push_back(k_type);
checkUInt64ToIn64Conversion(k_has_signed_int, k_uint64_convert_possible, k_type, tuple[0]);
DataTypePtr v_type = applyVisitor(*this, tuple[1]);
value_types.push_back(v_type);
checkUInt64ToIn64Conversion(v_has_signed_int, v_uint64_convert_possible, v_type, tuple[1]);
key_types.push_back(applyVisitor(*this, tuple[0]));
value_types.push_back(applyVisitor(*this, tuple[1]));
}
if (k_has_signed_int && k_uint64_convert_possible)
convertUInt64ToInt64IfPossible(key_types);
if (v_has_signed_int && v_uint64_convert_possible)
convertUInt64ToInt64IfPossible(value_types);
return std::make_shared<DataTypeMap>(
getLeastSupertype<on_error>(key_types),
getLeastSupertype<on_error>(value_types));
@ -227,28 +205,6 @@ DataTypePtr FieldToDataType<on_error>::operator()(const bool &) const
return DataTypeFactory::instance().get("Bool");
}
/// Updates the two running flags for one element of a sequence:
///  - `has_signed_int` becomes true once a native signed integer type is seen;
///  - `uint64_convert_possible` becomes false once a UInt64 element holds a value above Int64 max.
/// After `uint64_convert_possible` has become false, later calls change nothing.
template <LeastSupertypeOnError on_error>
void FieldToDataType<on_error>::checkUInt64ToIn64Conversion(bool & has_signed_int, bool & uint64_convert_possible, const DataTypePtr & type, const Field & elem) const
{
    if (!uint64_convert_possible)
        return;

    if (WhichDataType(type).isNativeInt())
    {
        has_signed_int = true;
        return;
    }

    if (type->getTypeId() != TypeIndex::UInt64)
        return;

    /// The flag is known to be true here, so it only needs clearing when the value does not fit.
    if (elem.template get<UInt64>() > std::numeric_limits<Int64>::max())
        uint64_convert_possible = false;
}
/// Replaces, in place, every entry of `data_types` whose type id is UInt64 with a fresh DataTypeInt64.
/// Entries of any other type are left untouched.
template <LeastSupertypeOnError on_error>
void FieldToDataType<on_error>::convertUInt64ToInt64IfPossible(DataTypes & data_types) const
{
    for (auto & current : data_types)
    {
        if (current->getTypeId() != TypeIndex::UInt64)
            continue;

        current = std::make_shared<DataTypeInt64>();
    }
}
template class FieldToDataType<LeastSupertypeOnError::Throw>;
template class FieldToDataType<LeastSupertypeOnError::String>;
template class FieldToDataType<LeastSupertypeOnError::Null>;

View File

@ -45,16 +45,6 @@ public:
DataTypePtr operator() (const UInt256 & x) const;
DataTypePtr operator() (const Int256 & x) const;
DataTypePtr operator() (const bool & x) const;
private:
// The conditions for converting UInt64 to Int64 are:
// 1. The existence of Int.
// 2. The existence of UInt64, and the UInt64 value must be <= Int64.max.
void checkUInt64ToIn64Conversion(bool& has_signed_int, bool& uint64_convert_possible, const DataTypePtr & type, const Field & elem) const;
// Convert the UInt64 type to Int64 in order to cover other signed_integer types
// and obtain the least super type of all ints.
void convertUInt64ToInt64IfPossible(DataTypes & data_types) const;
};
}

View File

@ -198,6 +198,35 @@ DataTypePtr getNumericType(const TypeIndexSet & types)
return {};
}
/// Try to treat UInt64 as Int64 so that inferring field types does not fail with
/// "There is no supertype for types UInt64, Int64".
/// Example:
/// [-3236599669630092879, 5607475129431807682]
/// The first field is inferred as Int64 and the second one as UInt64, although the second also fits into Int64.
/// Int128 is not supported as a supertype for Int64 and UInt64, because Int128 is inefficient.
/// In this case, however, the result type can be inferred as Array(Int64).
///
/// `types` is scanned for DataTypeUInt64 instances; `types_set` is the set of type ids that gets rewritten.
void convertUInt64toInt64IfPossible(const DataTypes & types, TypeIndexSet & types_set)
{
    /// Without UInt64 in the set there is nothing to convert.
    if (!types_set.contains(TypeIndex::UInt64))
        return;

    /// The conversion only matters when at least one signed integer (Int8..Int64) is present.
    const bool has_signed_integer = types_set.contains(TypeIndex::Int8) || types_set.contains(TypeIndex::Int16)
        || types_set.contains(TypeIndex::Int32) || types_set.contains(TypeIndex::Int64);
    if (!has_signed_integer)
        return;

    /// One UInt64 that is not flagged as fitting into Int64 blocks the replacement.
    for (const auto & candidate : types)
    {
        const auto * as_uint64 = typeid_cast<const DataTypeUInt64 *>(candidate.get());
        if (as_uint64 && !as_uint64->canUnsignedBeSigned())
            return;
    }

    types_set.erase(TypeIndex::UInt64);
    types_set.insert(TypeIndex::Int64);
}
}
template <LeastSupertypeOnError on_error>
@ -592,6 +621,8 @@ DataTypePtr getLeastSupertype(const DataTypes & types)
/// For numeric types, the most complicated part.
{
/// First, if we have signed integers, try to convert all UInt64 to Int64 if possible.
convertUInt64toInt64IfPossible(types, type_ids);
auto numeric_type = getNumericType<on_error>(type_ids);
if (numeric_type)
return numeric_type;

View File

@ -1105,6 +1105,14 @@ public:
if (const auto * right_array = checkAndGetDataType<DataTypeArray>(arg_else.type.get()))
right_id = right_array->getNestedType()->getTypeId();
/// Special case when one column is Integer and another is UInt64 that can be actually Int64.
/// The result type for this case is Int64 and we need to change UInt64 type to Int64
/// so that NumberTraits::ResultOfIf will return Int64 instead of Int128.
if (isNativeInteger(arg_then.type) && isUInt64ThatCanBeInt64(arg_else.type))
right_id = TypeIndex::Int64;
else if (isNativeInteger(arg_else.type) && isUInt64ThatCanBeInt64(arg_then.type))
left_id = TypeIndex::Int64;
if (!(callOnBasicTypes<true, true, true, false>(left_id, right_id, call)
|| (res = executeTyped<UUID, UUID>(cond_col, arguments, result_type, input_rows_count))
|| (res = executeString(cond_col, arguments, result_type))

View File

@ -6,6 +6,6 @@ SELECT ifNotFinite(nan, 2);
SELECT ifNotFinite(-1 / 0, 2);
SELECT ifNotFinite(log(0), NULL);
SELECT ifNotFinite(sqrt(-1), -42);
SELECT ifNotFinite(1234567890123456789, -1234567890123456789); -- { serverError 386 }
SELECT ifNotFinite(12345678901234567890, -12345678901234567890); -- { serverError 386 }
SELECT ifNotFinite(NULL, 1);

View File

@ -1,2 +1,11 @@
[-4741124612489978151,-3236599669630092879,5607475129431807682]
[100,-100,5607475129431807682,5607475129431807683]
[[-4741124612489978151],[-3236599669630092879,5607475129431807682]]
[[-4741124612489978151,-3236599669630092879],[5607475129431807682]]
[(-4741124612489978151,1),(-3236599669630092879,2),(560747512943180768,3)]
[-4741124612489978151,1,-3236599669630092879,2,560747512943180768,3]
{-4741124612489978151:1,-3236599669630092879:2,5607475129431807682:3}
[{-4741124612489978151:1,-3236599669630092879:2,5607475129431807682:3},{-1:1}]
{1:-4741124612489978151,2:-3236599669630092879,3:5607475129431807682}
[{1:-4741124612489978151,2:-3236599669630092879,3:5607475129431807682},{-1:1}]
-1234567890123456789

View File

@ -1,2 +1,11 @@
select [-4741124612489978151, -3236599669630092879, 5607475129431807682];
select [100, -100, 5607475129431807682, 5607475129431807683];
select [[-4741124612489978151], [-3236599669630092879, 5607475129431807682]];
select [[-4741124612489978151, -3236599669630092879], [5607475129431807682]];
select [tuple(-4741124612489978151, 1), tuple(-3236599669630092879, 2), tuple(560747512943180768, 3)];
select array(-4741124612489978151, 1, -3236599669630092879, 2, 560747512943180768, 3);
select map(-4741124612489978151, 1, -3236599669630092879, 2, 5607475129431807682, 3);
select [map(-4741124612489978151, 1, -3236599669630092879, 2, 5607475129431807682, 3), map(-1, 1)];
select map(1, -4741124612489978151, 2, -3236599669630092879, 3, 5607475129431807682);
select [map(1, -4741124612489978151, 2, -3236599669630092879, 3, 5607475129431807682), map(-1, 1)];
select if(materialize(1), -1234567890123456789, 1234567890123456789);