From 6af36cb703a9ad23b4d152c2fd5febfcf141ad48 Mon Sep 17 00:00:00 2001 From: Artem Zuikov Date: Thu, 18 Jun 2020 13:18:28 +0300 Subject: [PATCH] CAST keep nullable (#11733) --- src/Core/Settings.h | 1 + src/Functions/FunctionsConversion.cpp | 6 ++++++ src/Functions/FunctionsConversion.h | 16 ++++++++++++---- src/Functions/if.cpp | 9 +++++---- src/Interpreters/castColumn.cpp | 3 ++- .../01322_cast_keep_nullable.reference | 10 ++++++++++ .../0_stateless/01322_cast_keep_nullable.sql | 19 +++++++++++++++++++ 7 files changed, 55 insertions(+), 9 deletions(-) create mode 100644 tests/queries/0_stateless/01322_cast_keep_nullable.reference create mode 100644 tests/queries/0_stateless/01322_cast_keep_nullable.sql diff --git a/src/Core/Settings.h b/src/Core/Settings.h index bba258f1d60..a4269b92907 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -379,6 +379,7 @@ struct Settings : public SettingsCollection \ M(SettingBool, allow_experimental_geo_types, false, "Allow geo data types such as Point, Ring, Polygon, MultiPolygon", 0) \ M(SettingBool, data_type_default_nullable, false, "Data types without NULL or NOT NULL will make Nullable", 0) \ + M(SettingBool, cast_keep_nullable, false, "CAST operator keep Nullable for result data type", 0) \ \ /** Obsolete settings that do nothing but left for compatibility reasons. Remove each one after half a year of obsolescence. */ \ \ diff --git a/src/Functions/FunctionsConversion.cpp b/src/Functions/FunctionsConversion.cpp index 0bd7d1a27e8..bbde6e04069 100644 --- a/src/Functions/FunctionsConversion.cpp +++ b/src/Functions/FunctionsConversion.cpp @@ -1,10 +1,16 @@ #include #include +#include namespace DB { +FunctionOverloadResolverImplPtr CastOverloadResolver::create(const Context & context) +{ + return createImpl(context.getSettingsRef().cast_keep_nullable); +} + void registerFunctionsConversion(FunctionFactory & factory) { factory.registerFunction(); diff --git a/src/Functions/FunctionsConversion.h b/src/Functions/FunctionsConversion.h index 83417a3229b..b23cac8c456 100644 --- a/src/Functions/FunctionsConversion.h +++ b/src/Functions/FunctionsConversion.h @@ -2377,10 +2377,13 @@ public: using MonotonicityForRange = FunctionCast::MonotonicityForRange; static constexpr auto name = "CAST"; - static FunctionOverloadResolverImplPtr create(const Context &) { return createImpl(); } - static FunctionOverloadResolverImplPtr createImpl() { return std::make_unique(); } - CastOverloadResolver() {} + static FunctionOverloadResolverImplPtr create(const Context & context); + static FunctionOverloadResolverImplPtr createImpl(bool keep_nullable) { return std::make_unique(keep_nullable); } + + CastOverloadResolver(bool keep_nullable_) + : keep_nullable(keep_nullable_) + {} String getName() const override { return name; } @@ -2415,13 +2418,18 @@ protected: " Instead there is a column with the following structure: " + column->dumpStructure(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - return DataTypeFactory::instance().get(type_col->getValue()); + DataTypePtr type = DataTypeFactory::instance().get(type_col->getValue()); + if (keep_nullable && arguments.front().type->isNullable()) + return makeNullable(type); + return type; } bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForLowCardinalityColumns() const override { return false; } private: + bool keep_nullable; + template static auto monotonicityForType(const DataType * const) { diff --git a/src/Functions/if.cpp b/src/Functions/if.cpp index 02c3d938d2b..c272dc98505 100644 --- a/src/Functions/if.cpp +++ b/src/Functions/if.cpp @@ -693,11 +693,12 @@ private: static ColumnPtr makeNullableColumnIfNot(const ColumnPtr & column) { - if (isColumnNullable(*column)) - return column; + auto materialized = materializeColumnIfConst(column); - return ColumnNullable::create( - materializeColumnIfConst(column), ColumnUInt8::create(column->size(), 0)); + if (isColumnNullable(*materialized)) + return materialized; + + return ColumnNullable::create(materialized, ColumnUInt8::create(column->size(), 0)); } static ColumnPtr getNestedColumn(const ColumnPtr & column) diff --git a/src/Interpreters/castColumn.cpp b/src/Interpreters/castColumn.cpp index 2e6604f7df5..756ccbc6d7e 100644 --- a/src/Interpreters/castColumn.cpp +++ b/src/Interpreters/castColumn.cpp @@ -29,7 +29,8 @@ ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type } }; - FunctionOverloadResolverPtr func_builder_cast = std::make_shared(CastOverloadResolver::createImpl()); + FunctionOverloadResolverPtr func_builder_cast = + std::make_shared(CastOverloadResolver::createImpl(false)); ColumnsWithTypeAndName arguments{ temporary_block.getByPosition(0), temporary_block.getByPosition(1) }; auto func_cast = func_builder_cast->build(arguments); diff --git a/tests/queries/0_stateless/01322_cast_keep_nullable.reference b/tests/queries/0_stateless/01322_cast_keep_nullable.reference new file mode 100644 index 00000000000..8ad99a10170 --- /dev/null +++ b/tests/queries/0_stateless/01322_cast_keep_nullable.reference @@ -0,0 +1,10 @@ +0 Int32 +0 Int32 +1 Nullable(Int32) +1 Nullable(Int32) +2 Nullable(Float32) +2 Nullable(UInt8) +3 Nullable(Int32) +\N Nullable(Int32) +42 Nullable(Int32) +\N Nullable(Int32) diff --git a/tests/queries/0_stateless/01322_cast_keep_nullable.sql b/tests/queries/0_stateless/01322_cast_keep_nullable.sql new file mode 100644 index 00000000000..10918717469 --- /dev/null +++ b/tests/queries/0_stateless/01322_cast_keep_nullable.sql @@ -0,0 +1,19 @@ +SET cast_keep_nullable = 0; + +SELECT CAST(toNullable(toInt32(0)) AS Int32) as x, toTypeName(x); +SELECT CAST(toNullable(toInt8(0)) AS Int32) as x, toTypeName(x); + +SET cast_keep_nullable = 1; + +SELECT CAST(toNullable(toInt32(1)) AS Int32) as x, toTypeName(x); +SELECT CAST(toNullable(toInt8(1)) AS Int32) as x, toTypeName(x); + +SELECT CAST(toNullable(toFloat32(2)), 'Float32') as x, toTypeName(x); +SELECT CAST(toNullable(toFloat32(2)), 'UInt8') as x, toTypeName(x); +SELECT CAST(toNullable(toFloat32(2)), 'UUID') as x, toTypeName(x); -- { serverError 70 } + +SELECT CAST(if(1 = 1, toNullable(toInt8(3)), NULL) AS Int32) as x, toTypeName(x); +SELECT CAST(if(1 = 0, toNullable(toInt8(3)), NULL) AS Int32) as x, toTypeName(x); + +SELECT CAST(a, 'Int32') as x, toTypeName(x) FROM (SELECT materialize(CAST(42, 'Nullable(UInt8)')) AS a); +SELECT CAST(a, 'Int32') as x, toTypeName(x) FROM (SELECT materialize(CAST(NULL, 'Nullable(UInt8)')) AS a);