#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int BAD_ARGUMENTS; extern const int ILLEGAL_TYPE_OF_ARGUMENT; } class FunctionCastOrDefault final : public IFunction { public: static constexpr auto name = "castOrDefault"; static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } explicit FunctionCastOrDefault(ContextPtr context_) : keep_nullable(context_->getSettingsRef().cast_keep_nullable) { std::cerr << "FunctionCastOrDefault::constructor" << std::endl; } String getName() const override { return name; } size_t getNumberOfArguments() const override { return 0; } bool isVariadic() const override { return true; } ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForLowCardinalityColumns() const override { return false; } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { if (arguments.size() < 2) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} expected 2 or 3 arguments. Actual {}", getName(), arguments.size()); const auto & type_column = arguments[1].column; const auto * type_column_typed = checkAndGetColumnConst(type_column.get()); if (!type_column_typed) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument to {} must be a constant string describing type." " Instead there is a column with the following structure: {}", getName(), type_column->dumpStructure()); DataTypePtr type = DataTypeFactory::instance().get(type_column_typed->getValue()); if (keep_nullable && arguments.front().type->isNullable()) return makeNullable(type); return type; } ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, size_t) const override { const auto & type_column = arguments[1].column; const auto * type_column_typed = checkAndGetColumnConst(type_column.get()); DataTypePtr type = DataTypeFactory::instance().get(type_column_typed->getValue()); const ColumnWithTypeAndName & column_to_cast = arguments[0]; auto non_const_column_to_cast = column_to_cast.column->convertToFullColumnIfConst(); ColumnWithTypeAndName column_to_cast_non_const { std::move(non_const_column_to_cast), column_to_cast.type, column_to_cast.name }; auto cast_result = castColumnAccurateOrNull(column_to_cast_non_const, type); const auto & cast_result_nullable = assert_cast(*cast_result); const auto & null_map_data = cast_result_nullable.getNullMapData(); const auto & nested_column = cast_result_nullable.getNestedColumn(); IColumn::MutablePtr result = type->createColumn(); result->reserve(null_map_data.size()); size_t start_insert_index = 0; /// Created separate branch because cast and inserting field from other column is slower if (arguments.size() == 3) { const auto & default_column_with_type = arguments[2]; const auto & default_column_type = default_column_with_type.type; auto default_column = default_column_with_type.column->convertToFullColumnIfConst(); ColumnWithTypeAndName default_column_to_cast_non_const { std::move(default_column), default_column_type, default_column_with_type.name }; auto default_column_casted = castColumnAccurate(default_column_to_cast_non_const, return_type); for (size_t i = 0; i < null_map_data.size(); ++i) { bool is_current_index_null = null_map_data[i]; if (!is_current_index_null) continue; result->insertRangeFrom(nested_column, start_insert_index, i - start_insert_index); result->insertFrom(*default_column_casted, i); start_insert_index = i + 1; } } else { for (size_t i = 0; i < null_map_data.size(); ++i) { bool is_current_index_null = null_map_data[i]; if (!is_current_index_null) continue; result->insertRangeFrom(nested_column, start_insert_index, i - start_insert_index); result->insertDefault(); start_insert_index = i + 1; } } result->insertRangeFrom(nested_column, start_insert_index, null_map_data.size() - start_insert_index); return result; } private: bool keep_nullable; }; template class FunctionCastOrDefaultTyped final : public IFunction { public: static constexpr auto name = Name::name; static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } explicit FunctionCastOrDefaultTyped(ContextPtr context_) : impl(context_) { } String getName() const override { return name; } private: size_t getNumberOfArguments() const override { return 0; } bool isVariadic() const override { return true; } bool useDefaultImplementationForNulls() const override { return false; } bool useDefaultImplementationForLowCardinalityColumns() const override { return false; } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { size_t additional_argument_index = 1; size_t scale = 0; std::string time_zone; if constexpr (IsDataTypeDecimal) { if (additional_argument_index < arguments.size()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "{} for decimal type requires additionae scale argument {}", getName()); const auto & scale_argument = arguments[additional_argument_index]; WhichDataType scale_argument_type(scale_argument.type); if (!scale_argument_type.isNativeUInt()) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "Function {} decimal scale should have native UInt type. Actual {}", scale_argument.type->getName()); } scale = arguments[additional_argument_index].column->getUInt(0); ++additional_argument_index; } if constexpr (std::is_same_v || std::is_same_v) { if (additional_argument_index < arguments.size()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "{} for DateTime or DateTime64 requires additional timezone argument {}", getName()); time_zone = extractTimeZoneNameFromFunctionArguments(arguments, additional_argument_index, 0); ++additional_argument_index; } std::shared_ptr cast_type; if constexpr (std::is_same_v) cast_type = std::make_shared(scale, time_zone); else if constexpr (IsDataTypeDecimal) cast_type = std::make_shared(Type::maxPrecision(), scale); else if constexpr (std::is_same_v || std::is_same_v) cast_type = std::make_shared(time_zone); else cast_type = std::make_shared(); ColumnWithTypeAndName type_argument = { DataTypeString().createColumnConst(arguments.begin()->column->size(), cast_type->getName()), std::make_shared(), "" }; ColumnsWithTypeAndName arguments_with_cast_type; arguments_with_cast_type.reserve(arguments.size()); arguments_with_cast_type.emplace_back(arguments[0]); arguments_with_cast_type.emplace_back(type_argument); if (additional_argument_index < arguments.size()) { arguments_with_cast_type.emplace_back(arguments[additional_argument_index]); ++additional_argument_index; } if (additional_argument_index < arguments.size()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "{} wrong arguments size", getName()); return impl.getReturnTypeImpl(arguments_with_cast_type); } ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_size) const override { size_t additional_arguments_size = IsDataTypeDecimal + (std::is_same_v || std::is_same_v); ColumnWithTypeAndName second_argument = { DataTypeString().createColumnConst(arguments.begin()->column->size(), result_type->getName()), std::make_shared(), "" }; ColumnsWithTypeAndName arguments_with_cast_type; arguments_with_cast_type.reserve(arguments.size()); arguments_with_cast_type.emplace_back(arguments[0]); arguments_with_cast_type.emplace_back(second_argument); size_t default_column_argument = 1 + additional_arguments_size; if (default_column_argument < arguments.size()) arguments_with_cast_type.emplace_back(arguments[default_column_argument]); return impl.executeImpl(arguments_with_cast_type, result_type, input_rows_size); } FunctionCastOrDefault impl; }; struct NameToUInt8OrDefault { static constexpr auto name = "toUInt8OrDefault"; }; struct NameToUInt16OrDefault { static constexpr auto name = "toUInt16OrDefault"; }; struct NameToUInt32OrDefault { static constexpr auto name = "toUInt32OrDefault"; }; struct NameToUInt64OrDefault { static constexpr auto name = "toUInt64OrDefault"; }; struct NameToUInt256OrDefault { static constexpr auto name = "toUInt256OrDefault"; }; struct NameToInt8OrDefault { static constexpr auto name = "toInt8OrDefault"; }; struct NameToInt16OrDefault { static constexpr auto name = "toInt16OrDefault"; }; struct NameToInt32OrDefault { static constexpr auto name = "toInt32OrDefault"; }; struct NameToInt64OrDefault { static constexpr auto name = "toInt64OrDefault"; }; struct NameToInt128OrDefault { static constexpr auto name = "toInt128OrDefault"; }; struct NameToInt256OrDefault { static constexpr auto name = "toInt256OrDefault"; }; struct NameToFloat32OrDefault { static constexpr auto name = "toFloat32OrDefault"; }; struct NameToFloat64OrDefault { static constexpr auto name = "toFloat64OrDefault"; }; struct NameToDateOrDefault { static constexpr auto name = "toDateOrDefault"; }; struct NameToDateTimeOrDefault { static constexpr auto name = "toDateTimeOrDefault"; }; struct NameToDateTime64OrDefault { static constexpr auto name = "toDateTime64OrDefault"; }; struct NameToDecimal32OrDefault { static constexpr auto name = "toDecimal32OrDefault"; }; struct NameToDecimal64OrDefault { static constexpr auto name = "toDecimal64OrDefault"; }; struct NameToDecimal128OrDefault { static constexpr auto name = "toDecimal128OrDefault"; }; struct NameToDecimal256OrDefault { static constexpr auto name = "toDecimal256OrDefault"; }; struct NameToUUIDOrDefault { static constexpr auto name = "toUUIDOrDefault"; }; using FunctionToUInt8OrDefault = FunctionCastOrDefaultTyped; using FunctionToUInt16OrDefault = FunctionCastOrDefaultTyped; using FunctionToUInt32OrDefault = FunctionCastOrDefaultTyped; using FunctionToUInt64OrDefault = FunctionCastOrDefaultTyped; using FunctionToUInt256OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt8OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt16OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt32OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt64OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt128OrDefault = FunctionCastOrDefaultTyped; using FunctionToInt256OrDefault = FunctionCastOrDefaultTyped; using FunctionToFloat32OrDefault = FunctionCastOrDefaultTyped; using FunctionToFloat64OrDefault = FunctionCastOrDefaultTyped; using FunctionToDateOrDefault = FunctionCastOrDefaultTyped; using FunctionToDateTimeOrDefault = FunctionCastOrDefaultTyped; using FunctionToDateTime64OrDefault = FunctionCastOrDefaultTyped; using FunctionToDecimal32OrDefault = FunctionCastOrDefaultTyped, NameToDecimal32OrDefault>; using FunctionToDecimal64OrDefault = FunctionCastOrDefaultTyped, NameToDecimal64OrDefault>; using FunctionToDecimal128OrDefault = FunctionCastOrDefaultTyped, NameToDecimal128OrDefault>; using FunctionToDecimal256OrDefault = FunctionCastOrDefaultTyped, NameToDecimal256OrDefault>; using FunctionToUUIDOrDefault = FunctionCastOrDefaultTyped; void registerFunctionCastOrDefault(FunctionFactory & factory) { factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); } }