supertypes for Decimals [CLICKHOUSE-3765]

This commit is contained in:
chertus 2018-09-06 13:48:54 +03:00
parent 3bd586cad9
commit 5bbfdc2208
3 changed files with 80 additions and 34 deletions

View File

@ -1,3 +1,5 @@
#include <unordered_set>
#include <IO/WriteBufferFromString.h>
#include <IO/Operators.h>
#include <Common/typeid_cast.h>
@ -11,6 +13,7 @@
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
namespace DB
@ -185,22 +188,19 @@ DataTypePtr getLeastSupertype(const DataTypes & types)
/// Non-recursive rules
std::unordered_set<TypeIndex> type_ids;
for (const auto & type : types)
type_ids.insert(type->getTypeId());
/// For String and FixedString, or for different FixedStrings, the common type is String.
/// No other types are compatible with Strings. TODO Enums?
{
bool have_string = false;
bool all_strings = true;
UInt32 have_string = type_ids.count(TypeIndex::String);
UInt32 have_fixed_string = type_ids.count(TypeIndex::FixedString);
for (const auto & type : types)
{
if (type->isStringOrFixedString())
have_string = true;
else
all_strings = false;
}
if (have_string)
if (have_string || have_fixed_string)
{
bool all_strings = type_ids.size() == (have_string + have_fixed_string);
if (!all_strings)
throw Exception(getExceptionMessagePrefix(types) + " because some of them are String/FixedString and some of them are not", ErrorCodes::NO_COMMON_TYPE);
@ -210,19 +210,12 @@ DataTypePtr getLeastSupertype(const DataTypes & types)
/// For Date and DateTime, the common type is DateTime. No other types are compatible.
{
bool have_date_or_datetime = false;
bool all_date_or_datetime = true;
UInt32 have_date = type_ids.count(TypeIndex::Date);
UInt32 have_datetime = type_ids.count(TypeIndex::DateTime);
for (const auto & type : types)
{
if (type->isDateOrDateTime())
have_date_or_datetime = true;
else
all_date_or_datetime = false;
}
if (have_date_or_datetime)
if (have_date || have_datetime)
{
bool all_date_or_datetime = type_ids.size() == (have_date + have_datetime);
if (!all_date_or_datetime)
throw Exception(getExceptionMessagePrefix(types) + " because some of them are Date/DateTime and some of them are not", ErrorCodes::NO_COMMON_TYPE);
@ -230,6 +223,35 @@ DataTypePtr getLeastSupertype(const DataTypes & types)
}
}
/// Decimals
{
UInt32 have_decimal32 = type_ids.count(TypeIndex::Decimal32);
UInt32 have_decimal64 = type_ids.count(TypeIndex::Decimal64);
UInt32 have_decimal128 = type_ids.count(TypeIndex::Decimal128);
if (have_decimal32 || have_decimal64 || have_decimal128)
{
bool all_are_decimals = type_ids.size() == (have_decimal32 + have_decimal64 + have_decimal128);
if (!all_are_decimals)
throw Exception(getExceptionMessagePrefix(types) + " because some of them are Decimals and some are not",
ErrorCodes::NO_COMMON_TYPE);
UInt32 max_scale = 0;
for (const auto & type : types)
{
UInt32 scale = getDecimalScale(*type);
if (scale > max_scale)
max_scale = scale;
}
if (have_decimal128)
return std::make_shared<DataTypeDecimal<Decimal128>>(DataTypeDecimal<Decimal128>::maxPrecision(), max_scale);
if (have_decimal64)
return std::make_shared<DataTypeDecimal<Decimal64>>(DataTypeDecimal<Decimal64>::maxPrecision(), max_scale);
return std::make_shared<DataTypeDecimal<Decimal32>>(DataTypeDecimal<Decimal32>::maxPrecision(), max_scale);
}
}
/// For numeric types, the most complicated part.
{
bool all_numbers = true;

View File

@ -22,6 +22,12 @@ Tuple(Decimal(9, 1), Decimal(18, 1), Decimal(38, 1)) Decimal(9, 1) Decimal(18, 1
[0.100,0.200,0.300,0.000] [0.000,0.100,0.200,0.300]
[0.400,0.500,0.600,0.000] [0.000,0.400,0.500,0.600]
[0.700,0.800,0.900,0.000] [0.000,0.700,0.800,0.900]
[0.100,0.200,0.300,0.000] Array(Decimal(9, 3))
[0.400,0.500,0.600,0.000] Array(Decimal(18, 3))
[0.700,0.800,0.900,0.000] Array(Decimal(38, 3))
[0.0000,0.1000,0.2000,0.3000] Array(Decimal(9, 4))
[0.0000,0.4000,0.5000,0.6000] Array(Decimal(18, 4))
[0.0000,0.7000,0.8000,0.9000] Array(Decimal(38, 4))
3 3 3
2 2 2
0 0 0
@ -36,3 +42,15 @@ Tuple(Decimal(9, 1), Decimal(18, 1), Decimal(38, 1)) Decimal(9, 1) Decimal(18, 1
1 0
2 0
3 0
[0.100,0.200,0.300,0.400,0.500,0.600] Array(Decimal(18, 3))
[0.100,0.200,0.300,0.700,0.800,0.900] Array(Decimal(38, 3))
[0.400,0.500,0.600,0.700,0.800,0.900] Array(Decimal(38, 3))
[0.100,0.200,0.300,1.100,1.200] Array(Decimal(9, 3))
[0.400,0.500,0.600,2.100,2.200] Array(Decimal(18, 3))
[0.700,0.800,0.900,3.100,3.200] Array(Decimal(38, 3))
[0.100,0.200,0.300,2.100,2.200] Array(Decimal(18, 3))
[0.100,0.200,0.300,3.100,3.200] Array(Decimal(38, 3))
[0.400,0.500,0.600,1.100,1.200] Array(Decimal(18, 3))
[0.400,0.500,0.600,3.100,3.200] Array(Decimal(38, 3))
[0.700,0.800,0.900,1.100,1.200] Array(Decimal(38, 3))
[0.700,0.800,0.900,2.100,2.200] Array(Decimal(38, 3))

View File

@ -42,12 +42,12 @@ SELECT arrayPushBack(a, toDecimal32(0, 3)), arrayPushFront(a, toDecimal32(0, 3))
SELECT arrayPushBack(b, toDecimal64(0, 3)), arrayPushFront(b, toDecimal64(0, 3)) FROM test.decimal;
SELECT arrayPushBack(c, toDecimal128(0, 3)), arrayPushFront(c, toDecimal128(0, 3)) FROM test.decimal;
SELECT arrayPushBack(a, toDecimal32(0, 2)) FROM test.decimal; -- { serverError 386 }
SELECT arrayPushBack(b, toDecimal64(0, 2)) FROM test.decimal; -- { serverError 386 }
SELECT arrayPushBack(c, toDecimal128(0, 2)) FROM test.decimal; -- { serverError 386 }
SELECT arrayPushFront(a, toDecimal32(0, 4)) FROM test.decimal; -- { serverError 386 }
SELECT arrayPushFront(b, toDecimal64(0, 4)) FROM test.decimal; -- { serverError 386 }
SELECT arrayPushFront(c, toDecimal128(0, 4)) FROM test.decimal; -- { serverError 386 }
SELECT arrayPushBack(a, toDecimal32(0, 2)) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayPushBack(b, toDecimal64(0, 2)) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayPushBack(c, toDecimal128(0, 2)) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayPushFront(a, toDecimal32(0, 4)) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayPushFront(b, toDecimal64(0, 4)) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayPushFront(c, toDecimal128(0, 4)) AS x, toTypeName(x) FROM test.decimal;
SELECT length(a), length(b), length(c) FROM test.decimal;
SELECT length(nest.a), length(nest.b), length(nest.c) FROM test.decimal;
@ -92,11 +92,17 @@ SELECT indexOf(c, toDecimal64(0.7, 3)) FROM test.decimal; -- { serverError 43 }
SELECT indexOf(c, toDecimal128(0.7, 2)) FROM test.decimal; -- { serverError 43 }
SELECT indexOf(c, toDecimal128(0.7, 4)) FROM test.decimal; -- { serverError 43 }
SELECT arrayConcat(a, b) FROM test.decimal; -- { serverError 386 }
SELECT arrayConcat(a, c) FROM test.decimal; -- { serverError 386 }
SELECT arrayConcat(b, c) FROM test.decimal; -- { serverError 386 }
SELECT arrayConcat(a, nest.a) FROM test.decimal; -- { serverError 386 }
SELECT arrayConcat(b, nest.b) FROM test.decimal; -- { serverError 386 }
SELECT arrayConcat(c, nest.c) FROM test.decimal; -- { serverError 386 }
SELECT arrayConcat(a, b) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(a, c) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(b, c) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(a, nest.a) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(b, nest.b) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(c, nest.c) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(a, nest.b) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(a, nest.c) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(b, nest.a) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(b, nest.c) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(c, nest.a) AS x, toTypeName(x) FROM test.decimal;
SELECT arrayConcat(c, nest.b) AS x, toTypeName(x) FROM test.decimal;
DROP TABLE IF EXISTS test.decimal;