diff --git a/src/Functions/registerFunctionsReinterpret.cpp b/src/Functions/registerFunctionsReinterpret.cpp index d2e43fbd52a..d82274ce9ed 100644 --- a/src/Functions/registerFunctionsReinterpret.cpp +++ b/src/Functions/registerFunctionsReinterpret.cpp @@ -3,13 +3,13 @@ namespace DB class FunctionFactory; -void registerFunctionsReinterpretStringAs(FunctionFactory & factory); +void registerFunctionsReinterpretAs(FunctionFactory & factory); void registerFunctionReinterpretAsString(FunctionFactory & factory); void registerFunctionReinterpretAsFixedString(FunctionFactory & factory); void registerFunctionsReinterpret(FunctionFactory & factory) { - registerFunctionsReinterpretStringAs(factory); + registerFunctionsReinterpretAs(factory); registerFunctionReinterpretAsString(factory); registerFunctionReinterpretAsFixedString(factory); } diff --git a/src/Functions/reinterpretStringAs.cpp b/src/Functions/reinterpretAs.cpp similarity index 54% rename from src/Functions/reinterpretStringAs.cpp rename to src/Functions/reinterpretAs.cpp index cd36d63dd46..1efadd4f491 100644 --- a/src/Functions/reinterpretStringAs.cpp +++ b/src/Functions/reinterpretAs.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -25,30 +26,60 @@ namespace ErrorCodes namespace { - -template -class FunctionReinterpretStringAs : public IFunction +template +class FunctionReinterpretAs : public IFunction { + template + static bool castType(const IDataType * type, F && f) + { + return castTypeToEither( + type, std::forward(f)); + } + + template + static void reinterpretImpl(const PaddedPODArray & from, PaddedPODArray & to) + { + const auto * from_reinterpret = reinterpret_cast(const_cast(from.data())); + to.resize(from.size()); + for (size_t i = 0; i < from.size(); ++i) + { + to[i] = from_reinterpret[i]; + } + } + public: static constexpr auto name = Name::name; - static FunctionPtr create(const Context &) { return std::make_shared(); } + static FunctionPtr create(const Context &) { return std::make_shared(); } using ToFieldType = typename ToDataType::FieldType; using ColumnType = typename ToDataType::ColumnType; - String getName() const override - { - return name; - } + String getName() const override { return name; } size_t getNumberOfArguments() const override { return 1; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { const IDataType & type = *arguments[0]; - if (!isStringOrFixedString(type)) - throw Exception("Cannot reinterpret " + type.getName() + " as " + ToDataType().getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if constexpr (support_between_float_integer) + { + if (!isStringOrFixedString(type) && !isNumber(type)) + throw Exception( + "Cannot reinterpret " + type.getName() + " as " + ToDataType().getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if (isNumber(type)) + { + if (type.getSizeOfValueInMemory() != ToDataType{}.getSizeOfValueInMemory()) + throw Exception( + "Cannot reinterpret " + type.getName() + " as " + ToDataType().getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + } + else + { + if (!isStringOrFixedString(type)) + throw Exception( + "Cannot reinterpret " + type.getName() + " as " + ToDataType().getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } return std::make_shared(); } @@ -99,10 +130,34 @@ public: return col_res; } + else if constexpr (support_between_float_integer) + { + ColumnPtr res; + if (castType(arguments[0].type.get(), [&](const auto & type) { + using DataType = std::decay_t; + using T = typename DataType::FieldType; + + const ColumnVector * col = checkAndGetColumn>(arguments[0].column.get()); + auto col_res = ColumnType::create(); + reinterpretImpl(col->getData(), col_res->getData()); + res = std::move(col_res); + + return true; + })) + { + return res; + } + else + { + throw Exception( + "Illegal column " + arguments[0].column->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_COLUMN); + } + } else { - throw Exception("Illegal column " + arguments[0].column->getName() - + " of argument of function " + getName(), + throw Exception( + "Illegal column " + arguments[0].column->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); } } @@ -123,23 +178,22 @@ struct NameReinterpretAsDate { static constexpr auto name = "reinterpretA struct NameReinterpretAsDateTime { static constexpr auto name = "reinterpretAsDateTime"; }; struct NameReinterpretAsUUID { static constexpr auto name = "reinterpretAsUUID"; }; -using FunctionReinterpretAsUInt8 = FunctionReinterpretStringAs; -using FunctionReinterpretAsUInt16 = FunctionReinterpretStringAs; -using FunctionReinterpretAsUInt32 = FunctionReinterpretStringAs; -using FunctionReinterpretAsUInt64 = FunctionReinterpretStringAs; -using FunctionReinterpretAsInt8 = FunctionReinterpretStringAs; -using FunctionReinterpretAsInt16 = FunctionReinterpretStringAs; -using FunctionReinterpretAsInt32 = FunctionReinterpretStringAs; -using FunctionReinterpretAsInt64 = FunctionReinterpretStringAs; -using FunctionReinterpretAsFloat32 = FunctionReinterpretStringAs; -using FunctionReinterpretAsFloat64 = FunctionReinterpretStringAs; -using FunctionReinterpretAsDate = FunctionReinterpretStringAs; -using FunctionReinterpretAsDateTime = FunctionReinterpretStringAs; -using FunctionReinterpretAsUUID = FunctionReinterpretStringAs; - +using FunctionReinterpretAsUInt8 = FunctionReinterpretAs; +using FunctionReinterpretAsUInt16 = FunctionReinterpretAs; +using FunctionReinterpretAsUInt32 = FunctionReinterpretAs; +using FunctionReinterpretAsUInt64 = FunctionReinterpretAs; +using FunctionReinterpretAsInt8 = FunctionReinterpretAs; +using FunctionReinterpretAsInt16 = FunctionReinterpretAs; +using FunctionReinterpretAsInt32 = FunctionReinterpretAs; +using FunctionReinterpretAsInt64 = FunctionReinterpretAs; +using FunctionReinterpretAsFloat32 = FunctionReinterpretAs; +using FunctionReinterpretAsFloat64 = FunctionReinterpretAs; +using FunctionReinterpretAsDate = FunctionReinterpretAs; +using FunctionReinterpretAsDateTime = FunctionReinterpretAs; +using FunctionReinterpretAsUUID = FunctionReinterpretAs; } -void registerFunctionsReinterpretStringAs(FunctionFactory & factory) +void registerFunctionsReinterpretAs(FunctionFactory & factory) { factory.registerFunction(); factory.registerFunction(); diff --git a/src/Functions/ya.make b/src/Functions/ya.make index 4c2cbaf5b1f..bc5af88b6af 100644 --- a/src/Functions/ya.make +++ b/src/Functions/ya.make @@ -359,7 +359,7 @@ SRCS( registerFunctionsVisitParam.cpp reinterpretAsFixedString.cpp reinterpretAsString.cpp - reinterpretStringAs.cpp + reinterpretAs.cpp repeat.cpp replaceAll.cpp replaceOne.cpp