unified NumpyDataTypes

This commit is contained in:
HowePa 2024-04-25 15:52:30 +08:00
parent ae17941e63
commit 5e8bc4402a
3 changed files with 99 additions and 57 deletions

View File

@ -1,10 +1,12 @@
#pragma once
#include <cstddef>
#include <Storages/NamedCollectionsHelpers.h>
#include <IO/WriteBufferFromString.h>
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int NOT_IMPLEMENTED;
}
enum class NumpyDataTypeIndex
@ -29,9 +31,9 @@ class NumpyDataType
public:
enum Endianness
{
LITTLE,
BIG,
NONE,
LITTLE = '<',
BIG = '>',
NONE = '|',
};
NumpyDataTypeIndex type_index;
@ -41,15 +43,18 @@ public:
Endianness getEndianness() const { return endianness; }
virtual NumpyDataTypeIndex getTypeIndex() const = 0;
virtual size_t getSize() const { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Function getSize() is not implemented"); }
virtual void setSize(size_t) { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Function setSize() is not implemented"); }
virtual String str() const { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Function str() is not implemented"); }
private:
protected:
Endianness endianness;
};
class NumpyDataTypeInt : public NumpyDataType
{
public:
NumpyDataTypeInt(Endianness endianness, size_t size_, bool is_signed_) : NumpyDataType(endianness), size(size_), is_signed(is_signed_)
NumpyDataTypeInt(Endianness endianness_, size_t size_, bool is_signed_) : NumpyDataType(endianness_), size(size_), is_signed(is_signed_)
{
switch (size)
{
@ -67,6 +72,14 @@ public:
return type_index;
}
bool isSigned() const { return is_signed; }
String str() const override
{
DB::WriteBufferFromOwnString buf;
writeChar(static_cast<char>(endianness), buf);
writeChar(is_signed ? 'i' : 'u', buf);
writeIntText(size, buf);
return buf.str();
}
private:
size_t size;
@ -76,7 +89,7 @@ private:
class NumpyDataTypeFloat : public NumpyDataType
{
public:
NumpyDataTypeFloat(Endianness endianness, size_t size_) : NumpyDataType(endianness), size(size_)
NumpyDataTypeFloat(Endianness endianness_, size_t size_) : NumpyDataType(endianness_), size(size_)
{
switch (size)
{
@ -92,6 +105,14 @@ public:
{
return type_index;
}
String str() const override
{
DB::WriteBufferFromOwnString buf;
writeChar(static_cast<char>(endianness), buf);
writeChar('f', buf);
writeIntText(size, buf);
return buf.str();
}
private:
size_t size;
};
@ -99,13 +120,22 @@ private:
class NumpyDataTypeString : public NumpyDataType
{
public:
NumpyDataTypeString(Endianness endianness, size_t size_) : NumpyDataType(endianness), size(size_)
NumpyDataTypeString(Endianness endianness_, size_t size_) : NumpyDataType(endianness_), size(size_)
{
type_index = NumpyDataTypeIndex::String;
}
NumpyDataTypeIndex getTypeIndex() const override { return type_index; }
size_t getSize() const { return size; }
size_t getSize() const override { return size; }
void setSize(size_t size_) override { size = size_; }
String str() const override
{
DB::WriteBufferFromOwnString buf;
writeChar(static_cast<char>(endianness), buf);
writeChar('S', buf);
writeIntText(size, buf);
return buf.str();
}
private:
size_t size;
};
@ -113,13 +143,13 @@ private:
class NumpyDataTypeUnicode : public NumpyDataType
{
public:
NumpyDataTypeUnicode(Endianness endianness, size_t size_) : NumpyDataType(endianness), size(size_)
NumpyDataTypeUnicode(Endianness endianness_, size_t size_) : NumpyDataType(endianness_), size(size_)
{
type_index = NumpyDataTypeIndex::Unicode;
}
NumpyDataTypeIndex getTypeIndex() const override { return type_index; }
size_t getSize() const { return size * 4; }
size_t getSize() const override { return size * 4; }
private:
size_t size;
};

View File

@ -45,16 +45,6 @@ void writeNumpyStrings(const ColumnPtr & column, size_t length, WriteBuffer & bu
}
String NpyOutputFormat::NumpyDataType::str() const
{
WriteBufferFromOwnString dtype;
writeChar(endianness, dtype);
writeChar(type, dtype);
writeIntText(size, dtype);
return dtype.str();
}
String NpyOutputFormat::shapeStr() const
{
WriteBufferFromOwnString shape;
@ -85,20 +75,48 @@ bool NpyOutputFormat::getNumpyDataType(const DataTypePtr & type)
{
switch (type->getTypeId())
{
case TypeIndex::Int8: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int8)); break;
case TypeIndex::Int16: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int16)); break;
case TypeIndex::Int32: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int32)); break;
case TypeIndex::Int64: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int64)); break;
case TypeIndex::UInt8: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt8)); break;
case TypeIndex::UInt16: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt16)); break;
case TypeIndex::UInt32: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt32)); break;
case TypeIndex::UInt64: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt64)); break;
case TypeIndex::Float32: numpy_data_type = NumpyDataType('<', 'f', sizeof(Float32)); break;
case TypeIndex::Float64: numpy_data_type = NumpyDataType('<', 'f', sizeof(Float64)); break;
case TypeIndex::FixedString: numpy_data_type = NumpyDataType('|', 'S', assert_cast<const DataTypeFixedString *>(type.get())->getN()); break;
case TypeIndex::String: numpy_data_type = NumpyDataType('|', 'S', 0); break;
case TypeIndex::Array: return getNumpyDataType(assert_cast<const DataTypeArray *>(type.get())->getNestedType());
default: nested_data_type = type; return false;
case TypeIndex::Int8:
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(Int8), true);
break;
case TypeIndex::Int16:
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(Int16), true);
break;
case TypeIndex::Int32:
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(Int32), true);
break;
case TypeIndex::Int64:
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(Int64), true);
break;
case TypeIndex::UInt8:
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(UInt8), false);
break;
case TypeIndex::UInt16:
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(UInt16), false);
break;
case TypeIndex::UInt32:
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(UInt32), false);
break;
case TypeIndex::UInt64:
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(UInt64), false);
break;
case TypeIndex::Float32:
numpy_data_type = std::make_shared<NumpyDataTypeFloat>(NumpyDataType::Endianness::LITTLE, sizeof(Float32));
break;
case TypeIndex::Float64:
numpy_data_type = std::make_shared<NumpyDataTypeFloat>(NumpyDataType::Endianness::LITTLE, sizeof(Float64));
break;
case TypeIndex::FixedString:
numpy_data_type = std::make_shared<NumpyDataTypeString>(
NumpyDataType::Endianness::NONE, assert_cast<const DataTypeFixedString *>(type.get())->getN());
break;
case TypeIndex::String:
numpy_data_type = std::make_shared<NumpyDataTypeString>(NumpyDataType::Endianness::NONE, 0);
break;
case TypeIndex::Array:
return getNumpyDataType(assert_cast<const DataTypeArray *>(type.get())->getNestedType());
default:
nested_data_type = type;
return false;
}
nested_data_type = type;
@ -117,6 +135,9 @@ void NpyOutputFormat::consume(Chunk chunk)
initShape(column);
is_initialized = true;
}
// ColumnPtr checkShape, if nullptr?
// updateSizeIfTypeString
// columns.push_back()
if (!checkShape(column))
{
@ -130,13 +151,9 @@ void NpyOutputFormat::initShape(const ColumnPtr & column)
{
auto type = data_type;
ColumnPtr nested_column = column;
while (type->getTypeId() == TypeIndex::Array)
while (const auto * array_column = typeid_cast<const ColumnArray *>(nested_column.get()))
{
const auto * array_column = assert_cast<const ColumnArray *>(nested_column.get());
numpy_shape.push_back(array_column->getOffsets()[0]);
type = assert_cast<const DataTypeArray *>(type.get())->getNestedType();
nested_column = array_column->getDataPtr();
}
}
@ -166,7 +183,8 @@ bool NpyOutputFormat::checkShape(const ColumnPtr & column)
for (size_t i = 0; i < string_offsets.size(); ++i)
{
size_t string_length = static_cast<size_t>(string_offsets[i] - 1 - string_offsets[i - 1]);
numpy_data_type.size = numpy_data_type.size > string_length ? numpy_data_type.size : string_length;
if (numpy_data_type->getSize() < string_length)
numpy_data_type->setSize(string_length);
}
}
@ -185,7 +203,7 @@ void NpyOutputFormat::finalizeImpl()
void NpyOutputFormat::writeHeader()
{
String dict = "{'descr':'" + numpy_data_type.str() + "','fortran_order':False,'shape':(" + shapeStr() + "),}";
String dict = "{'descr':'" + numpy_data_type->str() + "','fortran_order':False,'shape':(" + shapeStr() + "),}";
String padding = "\n";
/// completes the length of the header, which is divisible by 64.
@ -221,9 +239,14 @@ void NpyOutputFormat::writeColumns()
case TypeIndex::UInt64: writeNumpyNumbers<ColumnUInt64, UInt64>(column, out); break;
case TypeIndex::Float32: writeNumpyNumbers<ColumnFloat32, Float32>(column, out); break;
case TypeIndex::Float64: writeNumpyNumbers<ColumnFloat64, Float64>(column, out); break;
case TypeIndex::FixedString: writeNumpyStrings<ColumnFixedString>(column, numpy_data_type.size, out); break;
case TypeIndex::String: writeNumpyStrings<ColumnString>(column, numpy_data_type.size, out); break;
default: break;
case TypeIndex::FixedString:
writeNumpyStrings<ColumnFixedString>(column, numpy_data_type->getSize(), out);
break;
case TypeIndex::String:
writeNumpyStrings<ColumnString>(column, numpy_data_type->getSize(), out);
break;
default:
break;
}
}
}

View File

@ -5,6 +5,7 @@
#include <IO/WriteBufferFromVector.h>
#include <Processors/Formats/IRowOutputFormat.h>
#include <Formats/FormatSettings.h>
#include <Formats/NumpyDataTypes.h>
#include <Columns/IColumn.h>
#include <Common/PODArray_fwd.h>
@ -28,18 +29,6 @@ public:
String getContentType() const override { return "application/octet-stream"; }
private:
struct NumpyDataType
{
char endianness;
char type;
size_t size;
NumpyDataType() = default;
NumpyDataType(char endianness_, char type_, size_t size_)
: endianness(endianness_), type(type_), size(size_) {}
String str() const;
};
String shapeStr() const;
bool getNumpyDataType(const DataTypePtr & type);
@ -57,7 +46,7 @@ private:
DataTypePtr data_type;
DataTypePtr nested_data_type;
NumpyDataType numpy_data_type;
std::shared_ptr<NumpyDataType> numpy_data_type;
UInt64 num_rows = 0;
std::vector<UInt64> numpy_shape;
Columns columns;