Updated usage of different types during IN query

1. Added accurateCast function.
2. Use accurateCast in Set during execute.
3. Added accurateCast tests.
4. Updated select_in_different_types tests.
This commit is contained in:
Maksim Kita 2020-12-05 16:18:56 +03:00
parent 6f4bc77dbc
commit 0464859cfe
15 changed files with 294 additions and 231 deletions

View File

@ -430,7 +430,28 @@ SELECT toTypeName(CAST(x, 'Nullable(UInt16)')) FROM t_null
- [cast_keep_nullable](../../operations/settings/settings.md#cast_keep_nullable) setting
## accurateCastOrNull(x, T) {#type_conversion_function-accurate-cast}
## accurateCast(x, T) {#type_conversion_function-accurate-cast}
Converts `x` to the `T` data type. The difference from cast(x, T) is that accurateCast
does not allow overflow of numeric types during the cast: if the value of x does not fit
the bounds of type T, an exception is thrown.
Example
``` sql
SELECT cast(-1, 'UInt8') as uint8;
```
``` text
```
```sql
SELECT accurateCast(-1, 'UInt8') as uint8;
```
``` text
```
## accurateCastOrNull(x, T) {#type_conversion_function-accurate-cast_or_null}
Converts `x` to the `T` data type. Always returns a Nullable type and returns NULL
if the cast value is not representable in the target type.

View File

@ -33,7 +33,7 @@ struct AvgFraction
/// Allow division by zero as sometimes we need to return NaN.
/// Invoked only if either Numerator or Denominator is Decimal.
Float64 NO_SANITIZE_UNDEFINED divideIfAnyDecimal(UInt32 num_scale, UInt32 denom_scale [[maybe_unused]]) const
Float64 NO_SANITIZE_UNDEFINED divideIfAnyDecimal(UInt32 num_scale, UInt32 denom_scale [[maybe_unused]]) const
{
if constexpr (IsDecimalNumber<Numerator> && IsDecimalNumber<Denominator>)
{

View File

@ -515,23 +515,34 @@ inline bool NO_SANITIZE_UNDEFINED convertNumeric(From value, To & result)
return true;
}
if constexpr (std::is_floating_point_v<From> && std::is_floating_point_v<To>) {
if constexpr (std::is_floating_point_v<From> && std::is_floating_point_v<To>)
{
/// Note that NaNs don't compare equal to anything, but they are still in range of any Float type.
if (isNaN(value))
{
result = value;
return true;
}
if (value == std::numeric_limits<From>::infinity())
{
result = std::numeric_limits<To>::infinity();
return true;
}
if (value == -std::numeric_limits<From>::infinity())
{
result = -std::numeric_limits<To>::infinity();
return true;
}
}
if (accurate::greaterOp(value, std::numeric_limits<To>::max())
|| accurate::greaterOp(std::numeric_limits<To>::min(), value))
|| accurate::greaterOp(std::numeric_limits<To>::lowest(), value))
{
return false;
}
result = static_cast<To>(value);
return equalsOp(value, result);
}
}

View File

@ -1,21 +1,10 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsConversion.h>
#include <Interpreters/Context.h>
namespace DB
{
FunctionOverloadResolverImplPtr CastOverloadResolver::create(const Context & context)
{
return createImpl(context.getSettingsRef().cast_keep_nullable);
}
FunctionOverloadResolverImplPtr AccurateCastOverloadResolver::create(const Context &)
{
return std::make_unique<AccurateCastOverloadResolver>();
}
void registerFunctionFixedString(FunctionFactory & factory);
void registerFunctionsConversion(FunctionFactory & factory)
@ -49,8 +38,10 @@ void registerFunctionsConversion(FunctionFactory & factory)
registerFunctionFixedString(factory);
factory.registerFunction<FunctionToUnixTimestamp>();
factory.registerFunction<CastOverloadResolver>(FunctionFactory::CaseInsensitive);
factory.registerFunction<AccurateCastOverloadResolver>();
factory.registerFunction<CastOverloadResolver<CastType::nonAccurate>>(FunctionFactory::CaseInsensitive);
factory.registerFunction<CastOverloadResolver<CastType::accurate>>();
factory.registerFunction<CastOverloadResolver<CastType::accurateOrNull>>();
factory.registerFunction<FunctionToUInt8OrZero>();
factory.registerFunction<FunctionToUInt16OrZero>();

View File

@ -41,9 +41,10 @@
#include <Functions/FunctionsMiscellaneous.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/DateTimeTransforms.h>
#include <Functions/toFixedString.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <Columns/ColumnLowCardinality.h>
#include <Functions/toFixedString.h>
#include <Interpreters/Context.h>
namespace DB
@ -70,6 +71,7 @@ namespace ErrorCodes
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NOT_IMPLEMENTED;
extern const int CANNOT_INSERT_NULL_IN_ORDINARY_COLUMN;
extern const int DECIMAL_OVERFLOW;
}
@ -95,15 +97,14 @@ inline UInt32 extractToDecimalScale(const ColumnWithTypeAndName & named_column)
/// Function toUnixTimestamp has exactly the same implementation as toDateTime of String type.
struct NameToUnixTimestamp { static constexpr auto name = "toUnixTimestamp"; };
struct AccurateAdditions
struct AccurateConvertStrategyAdditions
{
UInt32 scale { 0 };
};
enum class ConvertStrategy
struct AccurateOrNullConvertStrategyAdditions
{
NonAccurate,
Accurate
UInt32 scale { 0 };
};
/** Conversion of number types to each other, enums to numbers, dates and datetimes to numbers and back: done by straight assignment.
@ -117,11 +118,9 @@ struct ConvertImpl
template <typename Additions = void *>
static ColumnPtr NO_SANITIZE_UNDEFINED execute(
const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t /*input_rows_count*/,
const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type [[maybe_unused]], size_t /*input_rows_count*/,
Additions additions [[maybe_unused]] = Additions())
{
static constexpr auto convert_strategy
= std::is_same_v<Additions, AccurateAdditions> ? ConvertStrategy::Accurate : ConvertStrategy::NonAccurate;
const ColumnWithTypeAndName & named_from = arguments[0];
using ColVecFrom = typename FromDataType::ColumnType;
@ -150,7 +149,8 @@ struct ConvertImpl
if constexpr (IsDataTypeDecimal<ToDataType>)
{
UInt32 scale;
if constexpr (convert_strategy == ConvertStrategy::Accurate)
if constexpr (std::is_same_v<Additions, AccurateConvertStrategyAdditions>
|| std::is_same_v<Additions, AccurateOrNullConvertStrategyAdditions>)
{
scale = additions.scale;
}
@ -171,7 +171,7 @@ struct ConvertImpl
ColumnUInt8::MutablePtr col_null_map_to;
ColumnUInt8::Container * vec_null_map_to [[maybe_unused]] = nullptr;
if constexpr (convert_strategy == ConvertStrategy::Accurate)
if constexpr (std::is_same_v<Additions, AccurateOrNullConvertStrategyAdditions>)
{
col_null_map_to = ColumnUInt8::create(size, false);
vec_null_map_to = &col_null_map_to->getData();
@ -179,95 +179,91 @@ struct ConvertImpl
for (size_t i = 0; i < size; ++i)
{
if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>)
if constexpr ((is_big_int_v<FromFieldType> || is_big_int_v<ToFieldType>)
&& (std::is_same_v<FromFieldType, UInt128> || std::is_same_v<ToFieldType, UInt128>))
{
try
{
if constexpr (IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
vec_to[i] = convertDecimals<FromDataType, ToDataType>(vec_from[i], vec_from.getScale(), vec_to.getScale());
else if constexpr (IsDataTypeDecimal<FromDataType> && IsDataTypeNumber<ToDataType>)
vec_to[i] = convertFromDecimal<FromDataType, ToDataType>(vec_from[i], vec_from.getScale());
else if constexpr (IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>)
vec_to[i] = convertToDecimal<FromDataType, ToDataType>(vec_from[i], vec_to.getScale());
else
{
throw Exception("Unsupported data type in conversion function", ErrorCodes::CANNOT_CONVERT_TYPE);
}
}
catch (...)
{
/// Handle decimal overflow that propagated as exception
if constexpr (convert_strategy == ConvertStrategy::Accurate)
(*vec_null_map_to)[i] = true;
else
throw;
}
if constexpr (std::is_same_v<Additions, AccurateOrNullConvertStrategyAdditions>)
(*vec_null_map_to)[i] = true;
else
throw Exception("Unexpected UInt128 to big int conversion", ErrorCodes::NOT_IMPLEMENTED);
}
else if constexpr (is_big_int_v<FromFieldType> || is_big_int_v<ToFieldType>)
{
if constexpr (std::is_same_v<FromFieldType, UInt128> || std::is_same_v<ToFieldType, UInt128>)
else {
if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeDecimal<ToDataType>)
{
if constexpr (convert_strategy == ConvertStrategy::Accurate)
(*vec_null_map_to)[i] = true;
else
throw Exception("Unexpected UInt128 to big int conversion", ErrorCodes::NOT_IMPLEMENTED);
}
/// If From Data is Nan or Inf, throw exception
else if (!isFinite(vec_from[i]))
{
if constexpr (convert_strategy == ConvertStrategy::Accurate)
vec_null_map_to[i] = true;
else
throw Exception("Unexpected inf or nan to big int conversion", ErrorCodes::NOT_IMPLEMENTED);
try
{
if constexpr (IsDataTypeDecimal<FromDataType> && IsDataTypeDecimal<ToDataType>)
vec_to[i] = convertDecimals<FromDataType, ToDataType>(vec_from[i], vec_from.getScale(), vec_to.getScale());
else if constexpr (IsDataTypeDecimal<FromDataType> && IsDataTypeNumber<ToDataType>)
vec_to[i] = convertFromDecimal<FromDataType, ToDataType>(vec_from[i], vec_from.getScale());
else if constexpr (IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>)
vec_to[i] = convertToDecimal<FromDataType, ToDataType>(vec_from[i], vec_to.getScale());
else
{
throw Exception("Unsupported data type in conversion function", ErrorCodes::CANNOT_CONVERT_TYPE);
}
}
catch (const Exception & exception)
{
if constexpr (std::is_same_v<Additions, AccurateOrNullConvertStrategyAdditions>)
{
if (exception.code() == ErrorCodes::CANNOT_CONVERT_TYPE || exception.code() == ErrorCodes::DECIMAL_OVERFLOW)
(*vec_null_map_to)[i] = true;
else
throw exception;
}
else
throw exception;
}
}
else
{
if constexpr (convert_strategy == ConvertStrategy::Accurate)
/// If From Data is Nan or Inf, throw exception
/// TODO: Probably this can be applied to all integers not just big integers
/// https://stackoverflow.com/questions/38795544/is-casting-of-infinity-to-integer-undefined
if constexpr (is_big_int_v<ToFieldType>)
{
if (accurate::greaterOp(vec_from[i], std::numeric_limits<ToFieldType>::max())
|| accurate::greaterOp(std::numeric_limits<ToFieldType>::min(), vec_from[i]))
if (!isFinite(vec_from[i]))
{
(*vec_null_map_to)[i] = true;
continue;
if constexpr (std::is_same_v<Additions, AccurateOrNullConvertStrategyAdditions>)
{
vec_null_map_to[i] = true;
continue;
}
else
throw Exception("Unexpected inf or nan to big int conversion", ErrorCodes::NOT_IMPLEMENTED);
}
}
ToFieldType from = bigint_cast<ToFieldType>(vec_from[i]);
vec_to[i] = static_cast<ToFieldType>(from);
}
}
else if constexpr (std::is_same_v<ToFieldType, UInt128> && sizeof(FromFieldType) <= sizeof(UInt64))
{
if constexpr (convert_strategy == ConvertStrategy::Accurate)
{
if (accurate::greaterOp(vec_from[i], std::numeric_limits<ToFieldType>::max())
|| accurate::greaterOp(std::numeric_limits<ToFieldType>::min(), vec_from[i]))
if constexpr (std::is_same_v<Additions, AccurateOrNullConvertStrategyAdditions>
|| std::is_same_v<Additions, AccurateConvertStrategyAdditions>)
{
(*vec_null_map_to)[i] = true;
continue;
bool convert_result = accurate::convertNumeric(vec_from[i], vec_to[i]);
if (!convert_result)
{
if (std::is_same_v<Additions, AccurateOrNullConvertStrategyAdditions>)
{
(*vec_null_map_to)[i] = true;
}
else
{
throw Exception(
"Value in column " + named_from.column->getName() + " cannot be safely converted into type "
+ result_type->getName(),
ErrorCodes::CANNOT_CONVERT_TYPE);
}
}
}
else
{
vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
}
}
UInt64 from = static_cast<UInt64>(vec_from[i]);
vec_to[i] = static_cast<ToFieldType>(from);
}
else
{
if constexpr (convert_strategy == ConvertStrategy::Accurate)
{
if (accurate::greaterOp(vec_from[i], std::numeric_limits<ToFieldType>::max())
|| accurate::greaterOp(std::numeric_limits<ToFieldType>::min(), vec_from[i]))
{
(*vec_null_map_to)[i] = true;
continue;
}
}
vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
}
}
if constexpr (convert_strategy == ConvertStrategy::Accurate)
if constexpr (std::is_same_v<Additions, AccurateOrNullConvertStrategyAdditions>)
return ColumnNullable::create(std::move(col_to), std::move(col_null_map_to));
else
return col_to;
@ -2028,6 +2024,12 @@ private:
struct NameCast { static constexpr auto name = "CAST"; };
/// Selects which flavour of cast FunctionCast / CastOverloadResolver implements.
enum class CastType {
/// Regular cast, registered under the name "CAST".
nonAccurate,
/// Registered as "accurateCast": a numeric value that cannot be safely converted
/// into the target type results in an exception.
accurate,
/// Registered as "accurateCastOrNull": the result is Nullable and a value that
/// cannot be represented in the target type becomes NULL.
accurateOrNull
};
class FunctionCast final : public IFunctionBaseImpl
{
public:
@ -2037,10 +2039,10 @@ public:
FunctionCast(const char * name_, MonotonicityForRange && monotonicity_for_range_
, const DataTypes & argument_types_, const DataTypePtr & return_type_
, std::optional<Diagnostic> diagnostic_, bool is_accurate_cast_or_null_)
, std::optional<Diagnostic> diagnostic_, CastType cast_type_)
: name(name_), monotonicity_for_range(std::move(monotonicity_for_range_))
, argument_types(argument_types_), return_type(return_type_), diagnostic(std::move(diagnostic_))
, is_accurate_cast_or_null(is_accurate_cast_or_null_)
, cast_type(cast_type_)
{
}
@ -2087,7 +2089,7 @@ private:
DataTypePtr return_type;
std::optional<Diagnostic> diagnostic;
bool is_accurate_cast_or_null;
CastType cast_type;
WrapperType createFunctionAdaptor(FunctionPtr function, const DataTypePtr & from_type) const
{
@ -2101,7 +2103,7 @@ private:
};
}
WrapperType createToNullableColumnWrapper() const
static WrapperType createToNullableColumnWrapper()
{
return [] (ColumnsWithTypeAndName &, const DataTypePtr & result_type, const ColumnNullable *, size_t input_rows_count)
{
@ -2116,7 +2118,8 @@ private:
{
TypeIndex from_type_index = from_type->getTypeId();
WhichDataType which(from_type_index);
bool can_apply_accurate_cast = is_accurate_cast_or_null && (which.isInt() || which.isUInt() || which.isFloat());
bool can_apply_accurate_cast = (cast_type == CastType::accurate || cast_type == CastType::accurateOrNull)
&& (which.isInt() || which.isUInt() || which.isFloat());
if (requested_result_is_nullable && checkAndGetDataType<DataTypeString>(from_type.get()))
{
@ -2131,10 +2134,9 @@ private:
return createFunctionAdaptor(function, from_type);
}
auto nullable_column_wrapper = createToNullableColumnWrapper();
bool is_accurate_cast = is_accurate_cast_or_null;
auto wrapper_cast_type = cast_type;
return [is_accurate_cast, nullable_column_wrapper, from_type_index, to_type]
return [wrapper_cast_type, from_type_index, to_type]
(ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, const ColumnNullable *column_nullable, size_t input_rows_count)
{
ColumnPtr result_column;
@ -2145,8 +2147,17 @@ private:
if constexpr (IsDataTypeNumber<LeftDataType> && IsDataTypeNumber<RightDataType>)
{
result_column = ConvertImpl<LeftDataType, RightDataType, NameCast>::execute(
arguments, result_type, input_rows_count, AccurateAdditions());
if (wrapper_cast_type == CastType::accurate)
{
result_column = ConvertImpl<LeftDataType, RightDataType, NameCast>::execute(
arguments, result_type, input_rows_count, AccurateConvertStrategyAdditions());
}
else
{
result_column = ConvertImpl<LeftDataType, RightDataType, NameCast>::execute(
arguments, result_type, input_rows_count, AccurateOrNullConvertStrategyAdditions());
}
return true;
}
@ -2156,8 +2167,11 @@ private:
/// Additionally check if callOnIndexAndDataType wasn't called at all.
if (!res)
{
if (is_accurate_cast)
if (wrapper_cast_type == CastType::accurateOrNull)
{
auto nullable_column_wrapper = FunctionCast::createToNullableColumnWrapper();
return nullable_column_wrapper(arguments, result_type, column_nullable, input_rows_count);
}
else
{
throw Exception{"Conversion from " + std::string(getTypeName(from_type_index)) + " to " + to_type->getName() + " is not supported",
@ -2180,7 +2194,7 @@ private:
if (!isStringOrFixedString(from_type))
throw Exception{"CAST AS FixedString is only implemented for types String and FixedString", ErrorCodes::NOT_IMPLEMENTED};
bool exception_mode_null = is_accurate_cast_or_null;
bool exception_mode_null = cast_type == CastType::accurateOrNull;
return [exception_mode_null, N] (ColumnsWithTypeAndName & arguments, const DataTypePtr &, const ColumnNullable *, size_t /*input_rows_count*/)
{
if (exception_mode_null)
@ -2202,17 +2216,16 @@ private:
|| which.isStringOrFixedString();
if (!ok)
{
if (is_accurate_cast_or_null)
if (cast_type == CastType::accurateOrNull)
return createToNullableColumnWrapper();
else
throw Exception{"Conversion from " + from_type->getName() + " to " + to_type->getName() + " is not supported",
ErrorCodes::CANNOT_CONVERT_TYPE};
}
auto nullable_column_wrapper = createToNullableColumnWrapper();
bool is_accurate_cast = is_accurate_cast_or_null;
auto wrapper_cast_type = cast_type;
return [is_accurate_cast, nullable_column_wrapper, type_index, scale, to_type]
return [wrapper_cast_type, type_index, scale, to_type]
(ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, const ColumnNullable *column_nullable, size_t input_rows_count)
{
ColumnPtr result_column;
@ -2223,11 +2236,22 @@ private:
if constexpr (IsDataTypeDecimalOrNumber<LeftDataType> && IsDataTypeDecimalOrNumber<RightDataType>)
{
if (is_accurate_cast)
if (wrapper_cast_type == CastType::accurate)
{
AccurateAdditions additions;
AccurateConvertStrategyAdditions additions;
additions.scale = scale;
result_column = ConvertImpl<LeftDataType, RightDataType, NameCast>::execute(arguments, result_type, input_rows_count, additions);
result_column = ConvertImpl<LeftDataType, RightDataType, NameCast>::execute(
arguments, result_type, input_rows_count, additions);
return true;
}
else if (wrapper_cast_type == CastType::accurateOrNull)
{
AccurateOrNullConvertStrategyAdditions additions;
additions.scale = scale;
result_column = ConvertImpl<LeftDataType, RightDataType, NameCast>::execute(
arguments, result_type, input_rows_count, additions);
return true;
}
}
@ -2240,8 +2264,11 @@ private:
/// Additionally check if callOnIndexAndDataType wasn't called at all.
if (!res)
{
if (is_accurate_cast)
if (wrapper_cast_type == CastType::accurateOrNull)
{
auto nullable_column_wrapper = FunctionCast::createToNullableColumnWrapper();
return nullable_column_wrapper(arguments, result_type, column_nullable, input_rows_count);
}
else
throw Exception{"Conversion from " + std::string(getTypeName(type_index)) + " to " + to_type->getName() + " is not supported",
ErrorCodes::CANNOT_CONVERT_TYPE};
@ -2263,7 +2290,7 @@ private:
}
else
{
if (is_accurate_cast_or_null)
if (cast_type == CastType::accurateOrNull)
return createToNullableColumnWrapper();
else
throw Exception{"Conversion from " + from_type_untyped->getName() + " to " + to_type->getName() +
@ -2397,7 +2424,7 @@ private:
}
else
{
if (is_accurate_cast_or_null)
if (cast_type == CastType::accurateOrNull)
return createToNullableColumnWrapper();
else
throw Exception{"Conversion from " + from_type->getName() + " to " + to_type->getName() + " is not supported",
@ -2504,7 +2531,7 @@ private:
{
if (!to_nested->isNullable())
{
if (is_accurate_cast_or_null)
if (cast_type == CastType::accurateOrNull)
{
return createToNullableColumnWrapper();
}
@ -2750,7 +2777,7 @@ private:
break;
}
if (is_accurate_cast_or_null)
if (cast_type == CastType::accurateOrNull)
return createToNullableColumnWrapper();
else
throw Exception{"Conversion from " + from_type->getName() + " to " + to_type->getName() + " is not supported",
@ -2815,15 +2842,26 @@ template <typename DataType>
}
};
template<CastType cast_type>
class CastOverloadResolver : public IFunctionOverloadResolverImpl
{
public:
using MonotonicityForRange = FunctionCast::MonotonicityForRange;
using Diagnostic = FunctionCast::Diagnostic;
static constexpr auto name = "CAST";
static constexpr auto accurate_cast_name = "accurateCast";
static constexpr auto accurate_cast_or_null_name = "accurateCastOrNull";
static constexpr auto cast_name = "CAST";
static constexpr auto name =
cast_type == CastType::accurate ? accurate_cast_name :
(cast_type == CastType::accurateOrNull ? accurate_cast_or_null_name : cast_name);
static FunctionOverloadResolverImplPtr create(const Context & context)
{
return createImpl(context.getSettingsRef().cast_keep_nullable);
}
static FunctionOverloadResolverImplPtr create(const Context & context);
static FunctionOverloadResolverImplPtr createImpl(bool keep_nullable, std::optional<Diagnostic> diagnostic = {})
{
return std::make_unique<CastOverloadResolver>(keep_nullable, std::move(diagnostic));
@ -2840,8 +2878,6 @@ public:
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
protected:
FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
{
DataTypes data_types(arguments.size());
@ -2850,7 +2886,7 @@ protected:
data_types[i] = arguments[i].type;
auto monotonicity = MonotonicityHelper::getMonotonicityInformation(arguments.front().type, return_type.get());
return std::make_unique<FunctionCast>(name, std::move(monotonicity), data_types, return_type, diagnostic, false);
return std::make_unique<FunctionCast>(name, std::move(monotonicity), data_types, return_type, diagnostic, cast_type);
}
DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const override
@ -2868,9 +2904,17 @@ protected:
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
DataTypePtr type = DataTypeFactory::instance().get(type_col->getValue<String>());
if (keep_nullable && arguments.front().type->isNullable())
if constexpr (cast_type == CastType::accurateOrNull)
{
return makeNullable(type);
return type;
}
else
{
if (keep_nullable && arguments.front().type->isNullable())
return makeNullable(type);
return type;
}
}
bool useDefaultImplementationForNulls() const override { return false; }
@ -2881,53 +2925,4 @@ private:
std::optional<Diagnostic> diagnostic;
};
class AccurateCastOverloadResolver : public IFunctionOverloadResolverImpl
{
public:
using MonotonicityForRange = FunctionCast::MonotonicityForRange;
using Diagnostic = FunctionCast::Diagnostic;
static constexpr auto name = "accurateCastOrNull";
static FunctionOverloadResolverImplPtr create(const Context & context);
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
FunctionBaseImplPtr build(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
{
DataTypes data_types(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i)
data_types[i] = arguments[i].type;
auto monotonicity = MonotonicityHelper::getMonotonicityInformation(arguments.front().type, return_type.get());
return std::make_unique<FunctionCast>(name, std::move(monotonicity), data_types, return_type, std::optional<Diagnostic>(), true);
}
protected:
DataTypePtr getReturnType(const ColumnsWithTypeAndName & arguments) const override
{
const auto & column = arguments.back().column;
if (!column)
throw Exception("Second argument to " + getName() + " must be a constant string describing type."
" Instead there is non-constant column of type " + arguments.back().type->getName(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto type_col = checkAndGetColumnConst<ColumnString>(column.get());
if (!type_col)
throw Exception("Second argument to " + getName() + " must be a constant string describing type."
" Instead there is a column with the following structure: " + column->dumpStructure(),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
DataTypePtr type = DataTypeFactory::instance().get(type_col->getValue<String>());
return makeNullable(type);
}
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return false; }
};
}

View File

@ -662,10 +662,10 @@ ActionsDAGPtr ActionsDAG::makeConvertingActions(
auto * right_arg = const_cast<Node *>(&actions_dag->addColumn(std::move(column), true));
auto * left_arg = src_node;
CastOverloadResolver::Diagnostic diagnostic = {src_node->result_name, res_elem.name};
FunctionCast::Diagnostic diagnostic = {src_node->result_name, res_elem.name};
FunctionOverloadResolverPtr func_builder_cast =
std::make_shared<FunctionOverloadResolverAdaptor>(
CastOverloadResolver::createImpl(false, std::move(diagnostic)));
CastOverloadResolver<CastType::nonAccurate>::createImpl(false, std::move(diagnostic)));
Inputs children = { left_arg, right_arg };
src_node = &actions_dag->addFunction(func_builder_cast, std::move(children), {}, true);

View File

@ -25,6 +25,7 @@
#include <Interpreters/evaluateConstantExpression.h>
#include <Interpreters/NullableUtils.h>
#include <Interpreters/sortBlock.h>
#include <Interpreters/castColumn.h>
#include <Interpreters/Context.h>
#include <Storages/MergeTree/KeyCondition.h>
@ -32,8 +33,6 @@
#include <ext/range.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <iostream>
namespace DB
{
@ -257,15 +256,19 @@ ColumnPtr Set::execute(const Block & block, bool negative) const
for (size_t i = 0; i < num_key_columns; ++i)
{
/// TODO: Optimize making cast only if types different
/// TODO: This case SELECT '1' IN (SELECT 1); should not work but with AccurateCastOrNull it works
ColumnPtr result;
auto & column_before_cast = block.safeGetByPosition(i);
ColumnWithTypeAndName column
ColumnWithTypeAndName column_to_cast
= {column_before_cast.column->convertToFullColumnIfConst(), column_before_cast.type, column_before_cast.name};
auto accurate_cast = AccurateCastOverloadResolver().build({column}, data_types[i]);
auto accurate_cast_executable = accurate_cast->prepare({column});
auto casted_column = accurate_cast_executable->execute({column}, data_types[i], column.column->size());
materialized_columns.emplace_back() = casted_column;
if (!transform_null_in) {
result = castColumn<CastType::accurateOrNull>(column_to_cast, data_types[i]);
} else {
result = castColumn<CastType::accurate>(column_to_cast, data_types[i]);
}
materialized_columns.emplace_back() = result;
key_columns.emplace_back() = materialized_columns.back().get();
}

View File

@ -1,34 +0,0 @@
#include <Core/Field.h>
#include <Interpreters/castColumn.h>
#include <Interpreters/ExpressionActions.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/FunctionsConversion.h>
namespace DB
{
ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type)
{
if (arg.type->equals(*type))
return arg.column;
ColumnsWithTypeAndName arguments
{
arg,
{
DataTypeString().createColumnConst(arg.column->size(), type->getName()),
std::make_shared<DataTypeString>(),
""
}
};
FunctionOverloadResolverPtr func_builder_cast =
std::make_shared<FunctionOverloadResolverAdaptor>(CastOverloadResolver::createImpl(false));
auto func_cast = func_builder_cast->build(arguments);
return func_cast->execute(arguments, type, arg.column->size());
}
}

View File

@ -2,8 +2,41 @@
#include <Core/ColumnWithTypeAndName.h>
#include <Functions/FunctionsConversion.h>
namespace DB
{
ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type);
/// Casts the column `arg` to the data type `type` using the cast flavour selected by `cast_type`.
/// If `arg` already has exactly the requested type, its column is returned as is.
/// Otherwise a cast function is built through CastOverloadResolver<cast_type> and executed
/// over all rows of the column.
template<CastType cast_type = CastType::nonAccurate>
ColumnPtr castColumn(const ColumnWithTypeAndName & arg, const DataTypePtr & type)
{
/// Nothing to convert: source and target types are equal.
if (arg.type->equals(*type))
return arg.column;
/// Arguments of the cast function: the value column and a constant String column
/// that carries the name of the target type.
ColumnsWithTypeAndName arguments
{
arg,
{
DataTypeString().createColumnConst(arg.column->size(), type->getName()),
std::make_shared<DataTypeString>(),
""
}
};
/// The resolver is created with keep_nullable = false.
FunctionOverloadResolverPtr func_builder_cast =
std::make_shared<FunctionOverloadResolverAdaptor>(CastOverloadResolver<cast_type>::createImpl(false));
auto func_cast = func_builder_cast->build(arguments);
/// For accurateOrNull the function is executed with a Nullable result type,
/// so the target type is wrapped into Nullable here.
if constexpr (cast_type == CastType::accurateOrNull)
{
return func_cast->execute(arguments, makeNullable(type), arg.column->size());
}
else
{
return func_cast->execute(arguments, type, arg.column->size());
}
}
}

View File

@ -143,7 +143,6 @@ SRCS(
TreeRewriter.cpp
addMissingDefaults.cpp
addTypeConversionToAST.cpp
castColumn.cpp
convertFieldToType.cpp
createBlockSelector.cpp
evaluateConstantExpression.cpp

View File

@ -1137,7 +1137,7 @@ bool KeyCondition::tryParseAtomFromAST(const ASTPtr & node, const Context & cont
ColumnsWithTypeAndName arguments{
{nullptr, key_expr_type, ""}, {DataTypeString().createColumnConst(1, common_type->getName()), common_type, ""}};
FunctionOverloadResolverPtr func_builder_cast
= std::make_shared<FunctionOverloadResolverAdaptor>(CastOverloadResolver::createImpl(false));
= std::make_shared<FunctionOverloadResolverAdaptor>(CastOverloadResolver<CastType::nonAccurate>::createImpl(false));
auto func_cast = func_builder_cast->build(arguments);
/// If we know the given range only contains one value, then we treat all functions as positive monotonic.

View File

@ -1,6 +1,11 @@
1
0
1
1
2
2
1
1
1
0
0

View File

@ -1,5 +1,7 @@
-- SELECT 1 IN (SELECT -1)
-- SELECT -1 IN (SELECT 1)
SELECT 1 IN (SELECT 1);
SELECT -1 IN (SELECT 1);
DROP TABLE IF EXISTS select_in_test;
CREATE TABLE select_in_test(value UInt8) ENGINE=TinyLog;
INSERT INTO select_in_test VALUES (1), (2), (3);
@ -24,5 +26,10 @@ SELECT value FROM select_in_test WHERE value IN (SELECT 2);
DROP TABLE select_in_test;
SELECT 1 IN (1);
-- Right now this works because of accurate cast. Need to discuss.
SELECT '1' IN (SELECT 1);
SELECT 1 IN (SELECT 1) SETTINGS transform_null_in = 1;
SELECT 1 IN (SELECT 'a') SETTINGS transform_null_in = 1;
SELECT 'a' IN (SELECT 1) SETTINGS transform_null_in = 1; -- { serverError 6 }
SELECT 1 IN (SELECT -1) SETTINGS transform_null_in = 1;
SELECT -1 IN (SELECT 1) SETTINGS transform_null_in = 1; -- { serverError 70 }

View File

@ -0,0 +1,8 @@
5
5
5
5
5
5
1.000000000
12

View File

@ -0,0 +1,24 @@
SELECT accurateCast(-1, 'UInt8'); -- { serverError 70 }
SELECT accurateCast(5, 'UInt8');
SELECT accurateCast(257, 'UInt8'); -- { serverError 70 }
SELECT accurateCast(-1, 'UInt16'); -- { serverError 70 }
SELECT accurateCast(5, 'UInt16');
SELECT accurateCast(65536, 'UInt16'); -- { serverError 70 }
SELECT accurateCast(-1, 'UInt32'); -- { serverError 70 }
SELECT accurateCast(5, 'UInt32');
SELECT accurateCast(4294967296, 'UInt32'); -- { serverError 70 }
SELECT accurateCast(-1, 'UInt64'); -- { serverError 70 }
SELECT accurateCast(5, 'UInt64');
SELECT accurateCast(-1, 'UInt256'); -- { serverError 70 }
SELECT accurateCast(5, 'UInt256');
SELECT accurateCast(-129, 'Int8'); -- { serverError 70 }
SELECT accurateCast(5, 'Int8');
SELECT accurateCast(128, 'Int8'); -- { serverError 70 }
SELECT accurateCast(10, 'Decimal32(9)'); -- { serverError 407 }
SELECT accurateCast(1, 'Decimal32(9)');
SELECT accurateCast(-10, 'Decimal32(9)'); -- { serverError 407 }
SELECT accurateCast('123', 'FixedString(2)'); -- { serverError 131 }
SELECT accurateCast('12', 'FixedString(2)');