From 5e8bc4402ab4df42d228c0474ee01fbb83c97a71 Mon Sep 17 00:00:00 2001 From: HowePa <2873679104@qq.com> Date: Thu, 25 Apr 2024 15:52:30 +0800 Subject: [PATCH] unified NumpyDataTypes --- src/Formats/NumpyDataTypes.h | 50 ++++++++-- .../Formats/Impl/NpyOutputFormat.cpp | 91 ++++++++++++------- src/Processors/Formats/Impl/NpyOutputFormat.h | 15 +-- 3 files changed, 99 insertions(+), 57 deletions(-) diff --git a/src/Formats/NumpyDataTypes.h b/src/Formats/NumpyDataTypes.h index 712797515c9..5cf2ebf5b40 100644 --- a/src/Formats/NumpyDataTypes.h +++ b/src/Formats/NumpyDataTypes.h @@ -1,10 +1,12 @@ #pragma once #include #include +#include namespace ErrorCodes { extern const int BAD_ARGUMENTS; + extern const int NOT_IMPLEMENTED; } enum class NumpyDataTypeIndex @@ -29,9 +31,9 @@ class NumpyDataType public: enum Endianness { - LITTLE, - BIG, - NONE, + LITTLE = '<', + BIG = '>', + NONE = '|', }; NumpyDataTypeIndex type_index; @@ -41,15 +43,18 @@ public: Endianness getEndianness() const { return endianness; } virtual NumpyDataTypeIndex getTypeIndex() const = 0; + virtual size_t getSize() const { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Function getSize() is not implemented"); } + virtual void setSize(size_t) { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Function setSize() is not implemented"); } + virtual String str() const { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Function str() is not implemented"); } -private: +protected: Endianness endianness; }; class NumpyDataTypeInt : public NumpyDataType { public: - NumpyDataTypeInt(Endianness endianness, size_t size_, bool is_signed_) : NumpyDataType(endianness), size(size_), is_signed(is_signed_) + NumpyDataTypeInt(Endianness endianness_, size_t size_, bool is_signed_) : NumpyDataType(endianness_), size(size_), is_signed(is_signed_) { switch (size) { @@ -67,6 +72,14 @@ public: return type_index; } bool isSigned() const { return is_signed; } + String str() const override + { + DB::WriteBufferFromOwnString buf; + writeChar(static_cast(endianness), buf); + writeChar(is_signed ? 'i' : 'u', buf); + writeIntText(size, buf); + return buf.str(); + } private: size_t size; @@ -76,7 +89,7 @@ private: class NumpyDataTypeFloat : public NumpyDataType { public: - NumpyDataTypeFloat(Endianness endianness, size_t size_) : NumpyDataType(endianness), size(size_) + NumpyDataTypeFloat(Endianness endianness_, size_t size_) : NumpyDataType(endianness_), size(size_) { switch (size) { @@ -92,6 +105,14 @@ public: { return type_index; } + String str() const override + { + DB::WriteBufferFromOwnString buf; + writeChar(static_cast(endianness), buf); + writeChar('f', buf); + writeIntText(size, buf); + return buf.str(); + } private: size_t size; }; @@ -99,13 +120,22 @@ private: class NumpyDataTypeString : public NumpyDataType { public: - NumpyDataTypeString(Endianness endianness, size_t size_) : NumpyDataType(endianness), size(size_) + NumpyDataTypeString(Endianness endianness_, size_t size_) : NumpyDataType(endianness_), size(size_) { type_index = NumpyDataTypeIndex::String; } NumpyDataTypeIndex getTypeIndex() const override { return type_index; } - size_t getSize() const { return size; } + size_t getSize() const override { return size; } + void setSize(size_t size_) override { size = size_; } + String str() const override + { + DB::WriteBufferFromOwnString buf; + writeChar(static_cast(endianness), buf); + writeChar('S', buf); + writeIntText(size, buf); + return buf.str(); + } private: size_t size; }; @@ -113,13 +143,13 @@ private: class NumpyDataTypeUnicode : public NumpyDataType { public: - NumpyDataTypeUnicode(Endianness endianness, size_t size_) : NumpyDataType(endianness), size(size_) + NumpyDataTypeUnicode(Endianness endianness_, size_t size_) : NumpyDataType(endianness_), size(size_) { type_index = NumpyDataTypeIndex::Unicode; } NumpyDataTypeIndex getTypeIndex() const override { return type_index; } - size_t getSize() const { return size * 4; } + size_t getSize() const override { return size * 4; } private: size_t size; }; diff --git a/src/Processors/Formats/Impl/NpyOutputFormat.cpp b/src/Processors/Formats/Impl/NpyOutputFormat.cpp index d54fc7e68f2..64272307e9d 100644 --- a/src/Processors/Formats/Impl/NpyOutputFormat.cpp +++ b/src/Processors/Formats/Impl/NpyOutputFormat.cpp @@ -45,16 +45,6 @@ void writeNumpyStrings(const ColumnPtr & column, size_t length, WriteBuffer & bu } -String NpyOutputFormat::NumpyDataType::str() const -{ - WriteBufferFromOwnString dtype; - writeChar(endianness, dtype); - writeChar(type, dtype); - writeIntText(size, dtype); - - return dtype.str(); -} - String NpyOutputFormat::shapeStr() const { WriteBufferFromOwnString shape; @@ -85,20 +75,48 @@ bool NpyOutputFormat::getNumpyDataType(const DataTypePtr & type) { switch (type->getTypeId()) { - case TypeIndex::Int8: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int8)); break; - case TypeIndex::Int16: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int16)); break; - case TypeIndex::Int32: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int32)); break; - case TypeIndex::Int64: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int64)); break; - case TypeIndex::UInt8: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt8)); break; - case TypeIndex::UInt16: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt16)); break; - case TypeIndex::UInt32: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt32)); break; - case TypeIndex::UInt64: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt64)); break; - case TypeIndex::Float32: numpy_data_type = NumpyDataType('<', 'f', sizeof(Float32)); break; - case TypeIndex::Float64: numpy_data_type = NumpyDataType('<', 'f', sizeof(Float64)); break; - case TypeIndex::FixedString: numpy_data_type = NumpyDataType('|', 'S', assert_cast(type.get())->getN()); break; - case TypeIndex::String: numpy_data_type = NumpyDataType('|', 'S', 0); break; - case TypeIndex::Array: return getNumpyDataType(assert_cast(type.get())->getNestedType()); - default: nested_data_type = type; return false; + case TypeIndex::Int8: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(Int8), true); + break; + case TypeIndex::Int16: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(Int16), true); + break; + case TypeIndex::Int32: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(Int32), true); + break; + case TypeIndex::Int64: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(Int64), true); + break; + case TypeIndex::UInt8: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(UInt8), false); + break; + case TypeIndex::UInt16: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(UInt16), false); + break; + case TypeIndex::UInt32: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(UInt32), false); + break; + case TypeIndex::UInt64: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(UInt64), false); + break; + case TypeIndex::Float32: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(Float32)); + break; + case TypeIndex::Float64: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::LITTLE, sizeof(Float64)); + break; + case TypeIndex::FixedString: + numpy_data_type = std::make_shared( + NumpyDataType::Endianness::NONE, assert_cast(type.get())->getN()); + break; + case TypeIndex::String: + numpy_data_type = std::make_shared(NumpyDataType::Endianness::NONE, 0); + break; + case TypeIndex::Array: + return getNumpyDataType(assert_cast(type.get())->getNestedType()); + default: + nested_data_type = type; + return false; } nested_data_type = type; @@ -117,6 +135,9 @@ void NpyOutputFormat::consume(Chunk chunk) initShape(column); is_initialized = true; } + // ColumnPtr checkShape, if nullptr? + // updateSizeIfTypeString + // columns.push_back() if (!checkShape(column)) { @@ -130,13 +151,9 @@ void NpyOutputFormat::initShape(const ColumnPtr & column) { auto type = data_type; ColumnPtr nested_column = column; - while (type->getTypeId() == TypeIndex::Array) + while (const auto * array_column = typeid_cast(nested_column.get())) { - const auto * array_column = assert_cast(nested_column.get()); - numpy_shape.push_back(array_column->getOffsets()[0]); - - type = assert_cast(type.get())->getNestedType(); nested_column = array_column->getDataPtr(); } } @@ -166,7 +183,8 @@ bool NpyOutputFormat::checkShape(const ColumnPtr & column) for (size_t i = 0; i < string_offsets.size(); ++i) { size_t string_length = static_cast(string_offsets[i] - 1 - string_offsets[i - 1]); - numpy_data_type.size = numpy_data_type.size > string_length ? numpy_data_type.size : string_length; + if (numpy_data_type->getSize() < string_length) + numpy_data_type->setSize(string_length); } } @@ -185,7 +203,7 @@ void NpyOutputFormat::finalizeImpl() void NpyOutputFormat::writeHeader() { - String dict = "{'descr':'" + numpy_data_type.str() + "','fortran_order':False,'shape':(" + shapeStr() + "),}"; + String dict = "{'descr':'" + numpy_data_type->str() + "','fortran_order':False,'shape':(" + shapeStr() + "),}"; String padding = "\n"; /// completes the length of the header, which is divisible by 64. @@ -221,9 +239,14 @@ void NpyOutputFormat::writeColumns() case TypeIndex::UInt64: writeNumpyNumbers(column, out); break; case TypeIndex::Float32: writeNumpyNumbers(column, out); break; case TypeIndex::Float64: writeNumpyNumbers(column, out); break; - case TypeIndex::FixedString: writeNumpyStrings(column, numpy_data_type.size, out); break; - case TypeIndex::String: writeNumpyStrings(column, numpy_data_type.size, out); break; - default: break; + case TypeIndex::FixedString: + writeNumpyStrings(column, numpy_data_type->getSize(), out); + break; + case TypeIndex::String: + writeNumpyStrings(column, numpy_data_type->getSize(), out); + break; + default: + break; } } } diff --git a/src/Processors/Formats/Impl/NpyOutputFormat.h b/src/Processors/Formats/Impl/NpyOutputFormat.h index 83fad657b2e..6859cf10e69 100644 --- a/src/Processors/Formats/Impl/NpyOutputFormat.h +++ b/src/Processors/Formats/Impl/NpyOutputFormat.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -28,18 +29,6 @@ public: String getContentType() const override { return "application/octet-stream"; } private: - struct NumpyDataType - { - char endianness; - char type; - size_t size; - - NumpyDataType() = default; - NumpyDataType(char endianness_, char type_, size_t size_) - : endianness(endianness_), type(type_), size(size_) {} - String str() const; - }; - String shapeStr() const; bool getNumpyDataType(const DataTypePtr & type); @@ -57,7 +46,7 @@ private: DataTypePtr data_type; DataTypePtr nested_data_type; - NumpyDataType numpy_data_type; + std::shared_ptr numpy_data_type; UInt64 num_rows = 0; std::vector numpy_shape; Columns columns;