#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int BAD_ARGUMENTS; extern const int ILLEGAL_TYPE_OF_ARGUMENT; } class FunctionCastOrDefault final : public IFunction { public: static constexpr auto name = "accurateCastOrDefault"; static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } explicit FunctionCastOrDefault(ContextPtr context_) : keep_nullable(context_->getSettingsRef().cast_keep_nullable) { } String getName() const override { return name; } size_t getNumberOfArguments() const override { return 0; } bool isVariadic() const override { return true; } ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForConstants() const override { return false; } bool useDefaultImplementationForLowCardinalityColumns() const override { return true; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { size_t arguments_size = arguments.size(); if (arguments_size != 2 && arguments_size != 3) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} expected 2 or 3 arguments. Actual {}", getName(), arguments_size); const auto & type_column = arguments[1].column; const auto * type_column_typed = checkAndGetColumnConst(type_column.get()); if (!type_column_typed) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument to {} must be a constant string describing type. Actual {}", getName(), arguments[1].type->getName()); DataTypePtr result_type = DataTypeFactory::instance().get(type_column_typed->getValue()); if (keep_nullable && arguments.front().type->isNullable()) result_type = makeNullable(result_type); if (arguments.size() == 3) { auto default_value_type = arguments[2].type; if (!areTypesEqual(result_type, default_value_type)) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "Default value type should be same as cast type. Expected {}. Actual {}", result_type->getName(), default_value_type->getName()); } } return result_type; } ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, size_t) const override { const ColumnWithTypeAndName & column_to_cast = arguments[0]; auto non_const_column_to_cast = column_to_cast.column->convertToFullColumnIfConst(); ColumnWithTypeAndName column_to_cast_non_const { std::move(non_const_column_to_cast), column_to_cast.type, column_to_cast.name }; auto cast_result = castColumnAccurateOrNull(column_to_cast_non_const, return_type); const auto & cast_result_nullable = assert_cast(*cast_result); const auto & null_map_data = cast_result_nullable.getNullMapData(); size_t null_map_data_size = null_map_data.size(); const auto & nested_column = cast_result_nullable.getNestedColumn(); auto result = return_type->createColumn(); result->reserve(null_map_data_size); ColumnNullable * result_nullable = nullptr; if (result->isNullable()) result_nullable = assert_cast(&*result); size_t start_insert_index = 0; Field default_value; ColumnPtr default_column; if (arguments.size() == 3) { auto default_values_column = arguments[2].column; if (isColumnConst(*default_values_column)) default_value = (*default_values_column)[0]; else default_column = default_values_column->convertToFullColumnIfConst(); } else { default_value = return_type->getDefault(); } for (size_t i = 0; i < null_map_data_size; ++i) { bool is_current_index_null = null_map_data[i]; if (!is_current_index_null) continue; if (i != start_insert_index) { if (result_nullable) result_nullable->insertRangeFromNotNullable(nested_column, start_insert_index, i - start_insert_index); else result->insertRangeFrom(nested_column, start_insert_index, i - start_insert_index); } if (default_column) result->insertFrom(*default_column, i); else result->insert(default_value); start_insert_index = i + 1; } if (null_map_data_size != start_insert_index) { if (result_nullable) result_nullable->insertRangeFromNotNullable(nested_column, start_insert_index, null_map_data_size - start_insert_index); else result->insertRangeFrom(nested_column, start_insert_index, null_map_data_size - start_insert_index); } return result; } private: bool keep_nullable; }; template class FunctionCastOrDefaultTyped final : public IFunction { public: static constexpr auto name = Name::name; static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } explicit FunctionCastOrDefaultTyped(ContextPtr context_) : impl(context_) { } String getName() const override { return name; } private: size_t getNumberOfArguments() const override { return 0; } bool isVariadic() const override { return true; } bool useDefaultImplementationForNulls() const override { return impl.useDefaultImplementationForNulls(); } bool useDefaultImplementationForLowCardinalityColumns() const override { return impl.useDefaultImplementationForLowCardinalityColumns();} bool useDefaultImplementationForConstants() const override { return impl.useDefaultImplementationForConstants();} bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & arguments) const override { return impl.isSuitableForShortCircuitArgumentsExecution(arguments); } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { FunctionArgumentDescriptors mandatory_args = {{"Value", nullptr, nullptr, nullptr}}; FunctionArgumentDescriptors optional_args; if constexpr (IsDataTypeDecimal) mandatory_args.push_back({"scale", &isNativeInteger, &isColumnConst, "const Integer"}); if (std::is_same_v || std::is_same_v) optional_args.push_back({"timezone", &isString, isColumnConst, "const String"}); optional_args.push_back({"default_value", nullptr, nullptr, nullptr}); validateFunctionArgumentTypes(*this, arguments, mandatory_args, optional_args); size_t additional_argument_index = 1; size_t scale = 0; std::string time_zone; if constexpr (IsDataTypeDecimal) { const auto & scale_argument = arguments[additional_argument_index]; WhichDataType scale_argument_type(scale_argument.type); if (!scale_argument_type.isNativeUInt()) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function {} decimal scale should have native UInt type. Actual {}", scale_argument.type->getName()); } scale = arguments[additional_argument_index].column->getUInt(0); ++additional_argument_index; } if constexpr (std::is_same_v || std::is_same_v) { if (additional_argument_index < arguments.size()) { time_zone = extractTimeZoneNameFromColumn(*arguments[additional_argument_index].column); ++additional_argument_index; } } std::shared_ptr cast_type; if constexpr (std::is_same_v) cast_type = std::make_shared(scale, time_zone); else if constexpr (IsDataTypeDecimal) cast_type = std::make_shared(Type::maxPrecision(), scale); else if constexpr (std::is_same_v || std::is_same_v) cast_type = std::make_shared(time_zone); else cast_type = std::make_shared(); ColumnWithTypeAndName type_argument = { DataTypeString().createColumnConst(1, cast_type->getName()), std::make_shared(), "" }; ColumnsWithTypeAndName arguments_with_cast_type; arguments_with_cast_type.reserve(arguments.size()); arguments_with_cast_type.emplace_back(arguments[0]); arguments_with_cast_type.emplace_back(type_argument); if (additional_argument_index < arguments.size()) { arguments_with_cast_type.emplace_back(arguments[additional_argument_index]); ++additional_argument_index; } if (additional_argument_index < arguments.size()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "{} wrong arguments size", getName()); return impl.getReturnTypeImpl(arguments_with_cast_type); } ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_size) const override { size_t additional_arguments_size = IsDataTypeDecimal + (std::is_same_v || std::is_same_v); ColumnWithTypeAndName second_argument = { DataTypeString().createColumnConst(arguments.begin()->column->size(), result_type->getName()), std::make_shared(), "" }; ColumnsWithTypeAndName arguments_with_cast_type; arguments_with_cast_type.reserve(arguments.size()); arguments_with_cast_type.emplace_back(arguments[0]); arguments_with_cast_type.emplace_back(second_argument); size_t default_column_argument = 1 + additional_arguments_size; if (default_column_argument < arguments.size()) arguments_with_cast_type.emplace_back(arguments[default_column_argument]); return impl.executeImpl(arguments_with_cast_type, result_type, input_rows_size); } FunctionCastOrDefault impl; }; struct NameToUInt8OrDefault { static constexpr auto name = "toUInt8OrDefault"; }; struct NameToUInt16OrDefault { static constexpr auto name = "toUInt16OrDefault"; }; struct NameToUInt32OrDefault { static constexpr auto name = "toUInt32OrDefault"; }; struct NameToUInt64OrDefault { static constexpr auto name = "toUInt64OrDefault"; }; struct NameToUInt256OrDefault { static constexpr auto name = "toUInt256OrDefault"; }; struct NameToInt8OrDefault { static constexpr auto name = "toInt8OrDefault"; }; struct NameToInt16OrDefault { static constexpr auto name = "toInt16OrDefault"; }; struct NameToInt32OrDefault { static constexpr auto name = "toInt32OrDefault"; }; struct NameToInt64OrDefault { static constexpr auto name = "toInt64OrDefault"; }; struct NameToInt128OrDefault { static constexpr auto name = "toInt128OrDefault"; }; struct NameToInt256OrDefault { static constexpr auto name = "toInt256OrDefault"; }; struct NameToFloat32OrDefault { static constexpr auto name = "toFloat32OrDefault"; }; struct NameToFloat64OrDefault { static constexpr auto name = "toFloat64OrDefault"; }; struct NameToDateOrDefault { static constexpr auto name = "toDateOrDefault"; }; struct NameToDate32OrDefault { static constexpr auto name = "toDate32OrDefault"; }; struct NameToDateTimeOrDefault { static constexpr auto name = "toDateTimeOrDefault"; }; struct NameToDateTime64OrDefault { static constexpr auto name = "toDateTime64OrDefault"; }; struct NameToDecimal32OrDefault { static constexpr auto name = "toDecimal32OrDefault"; }; struct NameToDecimal64OrDefault { static constexpr auto name = "toDecimal64OrDefault"; }; struct NameToDecimal128OrDefault { static constexpr auto name = "toDecimal128OrDefault"; }; struct NameToDecimal256OrDefault { static constexpr auto name = "toDecimal256OrDefault"; }; struct NameToUUIDOrDefault { static constexpr auto name = "toUUIDOrDefault"; }; using FunctionToUInt8OrDefault = FunctionCastOrDefaultTyped; using FunctionToUInt16OrDefault = FunctionCastOrDefaultTyped; using FunctionToUInt32OrDefault = FunctionCastOrDefaultTyped; using FunctionToUInt64OrDefault = FunctionCastOrDefaultTyped; using FunctionToUInt256OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt8OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt16OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt32OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt64OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt128OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt256OrDefault = FunctionCastOrDefaultTyped; using FunctionToFloat32OrDefault = FunctionCastOrDefaultTyped; using FunctionToFloat64OrDefault = FunctionCastOrDefaultTyped; using FunctionToDateOrDefault = FunctionCastOrDefaultTyped; using FunctionToDate32OrDefault = FunctionCastOrDefaultTyped; using FunctionToDateTimeOrDefault = FunctionCastOrDefaultTyped; using FunctionToDateTime64OrDefault = FunctionCastOrDefaultTyped; using FunctionToDecimal32OrDefault = FunctionCastOrDefaultTyped, NameToDecimal32OrDefault>; using FunctionToDecimal64OrDefault = FunctionCastOrDefaultTyped, NameToDecimal64OrDefault>; using FunctionToDecimal128OrDefault = FunctionCastOrDefaultTyped, NameToDecimal128OrDefault>; using FunctionToDecimal256OrDefault = FunctionCastOrDefaultTyped, NameToDecimal256OrDefault>; using FunctionToUUIDOrDefault = FunctionCastOrDefaultTyped; void registerFunctionCastOrDefault(FunctionFactory & factory) { factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); } }