mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-18 20:32:43 +00:00
unified NumpyDataTypes
This commit is contained in:
parent
ae17941e63
commit
5e8bc4402a
@ -1,10 +1,12 @@
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
#include <Storages/NamedCollectionsHelpers.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int NOT_IMPLEMENTED;
|
||||
}
|
||||
|
||||
enum class NumpyDataTypeIndex
|
||||
@ -29,9 +31,9 @@ class NumpyDataType
|
||||
public:
|
||||
enum Endianness
|
||||
{
|
||||
LITTLE,
|
||||
BIG,
|
||||
NONE,
|
||||
LITTLE = '<',
|
||||
BIG = '>',
|
||||
NONE = '|',
|
||||
};
|
||||
NumpyDataTypeIndex type_index;
|
||||
|
||||
@ -41,15 +43,18 @@ public:
|
||||
Endianness getEndianness() const { return endianness; }
|
||||
|
||||
virtual NumpyDataTypeIndex getTypeIndex() const = 0;
|
||||
virtual size_t getSize() const { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Function getSize() is not implemented"); }
|
||||
virtual void setSize(size_t) { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Function setSize() is not implemented"); }
|
||||
virtual String str() const { throw DB::Exception(DB::ErrorCodes::NOT_IMPLEMENTED, "Function str() is not implemented"); }
|
||||
|
||||
private:
|
||||
protected:
|
||||
Endianness endianness;
|
||||
};
|
||||
|
||||
class NumpyDataTypeInt : public NumpyDataType
|
||||
{
|
||||
public:
|
||||
NumpyDataTypeInt(Endianness endianness, size_t size_, bool is_signed_) : NumpyDataType(endianness), size(size_), is_signed(is_signed_)
|
||||
NumpyDataTypeInt(Endianness endianness_, size_t size_, bool is_signed_) : NumpyDataType(endianness_), size(size_), is_signed(is_signed_)
|
||||
{
|
||||
switch (size)
|
||||
{
|
||||
@ -67,6 +72,14 @@ public:
|
||||
return type_index;
|
||||
}
|
||||
bool isSigned() const { return is_signed; }
|
||||
String str() const override
|
||||
{
|
||||
DB::WriteBufferFromOwnString buf;
|
||||
writeChar(static_cast<char>(endianness), buf);
|
||||
writeChar(is_signed ? 'i' : 'u', buf);
|
||||
writeIntText(size, buf);
|
||||
return buf.str();
|
||||
}
|
||||
|
||||
private:
|
||||
size_t size;
|
||||
@ -76,7 +89,7 @@ private:
|
||||
class NumpyDataTypeFloat : public NumpyDataType
|
||||
{
|
||||
public:
|
||||
NumpyDataTypeFloat(Endianness endianness, size_t size_) : NumpyDataType(endianness), size(size_)
|
||||
NumpyDataTypeFloat(Endianness endianness_, size_t size_) : NumpyDataType(endianness_), size(size_)
|
||||
{
|
||||
switch (size)
|
||||
{
|
||||
@ -92,6 +105,14 @@ public:
|
||||
{
|
||||
return type_index;
|
||||
}
|
||||
String str() const override
|
||||
{
|
||||
DB::WriteBufferFromOwnString buf;
|
||||
writeChar(static_cast<char>(endianness), buf);
|
||||
writeChar('f', buf);
|
||||
writeIntText(size, buf);
|
||||
return buf.str();
|
||||
}
|
||||
private:
|
||||
size_t size;
|
||||
};
|
||||
@ -99,13 +120,22 @@ private:
|
||||
class NumpyDataTypeString : public NumpyDataType
|
||||
{
|
||||
public:
|
||||
NumpyDataTypeString(Endianness endianness, size_t size_) : NumpyDataType(endianness), size(size_)
|
||||
NumpyDataTypeString(Endianness endianness_, size_t size_) : NumpyDataType(endianness_), size(size_)
|
||||
{
|
||||
type_index = NumpyDataTypeIndex::String;
|
||||
}
|
||||
|
||||
NumpyDataTypeIndex getTypeIndex() const override { return type_index; }
|
||||
size_t getSize() const { return size; }
|
||||
size_t getSize() const override { return size; }
|
||||
void setSize(size_t size_) override { size = size_; }
|
||||
String str() const override
|
||||
{
|
||||
DB::WriteBufferFromOwnString buf;
|
||||
writeChar(static_cast<char>(endianness), buf);
|
||||
writeChar('S', buf);
|
||||
writeIntText(size, buf);
|
||||
return buf.str();
|
||||
}
|
||||
private:
|
||||
size_t size;
|
||||
};
|
||||
@ -113,13 +143,13 @@ private:
|
||||
class NumpyDataTypeUnicode : public NumpyDataType
|
||||
{
|
||||
public:
|
||||
NumpyDataTypeUnicode(Endianness endianness, size_t size_) : NumpyDataType(endianness), size(size_)
|
||||
NumpyDataTypeUnicode(Endianness endianness_, size_t size_) : NumpyDataType(endianness_), size(size_)
|
||||
{
|
||||
type_index = NumpyDataTypeIndex::Unicode;
|
||||
}
|
||||
|
||||
NumpyDataTypeIndex getTypeIndex() const override { return type_index; }
|
||||
size_t getSize() const { return size * 4; }
|
||||
size_t getSize() const override { return size * 4; }
|
||||
private:
|
||||
size_t size;
|
||||
};
|
||||
|
@ -45,16 +45,6 @@ void writeNumpyStrings(const ColumnPtr & column, size_t length, WriteBuffer & bu
|
||||
|
||||
}
|
||||
|
||||
String NpyOutputFormat::NumpyDataType::str() const
|
||||
{
|
||||
WriteBufferFromOwnString dtype;
|
||||
writeChar(endianness, dtype);
|
||||
writeChar(type, dtype);
|
||||
writeIntText(size, dtype);
|
||||
|
||||
return dtype.str();
|
||||
}
|
||||
|
||||
String NpyOutputFormat::shapeStr() const
|
||||
{
|
||||
WriteBufferFromOwnString shape;
|
||||
@ -85,20 +75,48 @@ bool NpyOutputFormat::getNumpyDataType(const DataTypePtr & type)
|
||||
{
|
||||
switch (type->getTypeId())
|
||||
{
|
||||
case TypeIndex::Int8: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int8)); break;
|
||||
case TypeIndex::Int16: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int16)); break;
|
||||
case TypeIndex::Int32: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int32)); break;
|
||||
case TypeIndex::Int64: numpy_data_type = NumpyDataType('<', 'i', sizeof(Int64)); break;
|
||||
case TypeIndex::UInt8: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt8)); break;
|
||||
case TypeIndex::UInt16: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt16)); break;
|
||||
case TypeIndex::UInt32: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt32)); break;
|
||||
case TypeIndex::UInt64: numpy_data_type = NumpyDataType('<', 'u', sizeof(UInt64)); break;
|
||||
case TypeIndex::Float32: numpy_data_type = NumpyDataType('<', 'f', sizeof(Float32)); break;
|
||||
case TypeIndex::Float64: numpy_data_type = NumpyDataType('<', 'f', sizeof(Float64)); break;
|
||||
case TypeIndex::FixedString: numpy_data_type = NumpyDataType('|', 'S', assert_cast<const DataTypeFixedString *>(type.get())->getN()); break;
|
||||
case TypeIndex::String: numpy_data_type = NumpyDataType('|', 'S', 0); break;
|
||||
case TypeIndex::Array: return getNumpyDataType(assert_cast<const DataTypeArray *>(type.get())->getNestedType());
|
||||
default: nested_data_type = type; return false;
|
||||
case TypeIndex::Int8:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(Int8), true);
|
||||
break;
|
||||
case TypeIndex::Int16:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(Int16), true);
|
||||
break;
|
||||
case TypeIndex::Int32:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(Int32), true);
|
||||
break;
|
||||
case TypeIndex::Int64:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(Int64), true);
|
||||
break;
|
||||
case TypeIndex::UInt8:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(UInt8), false);
|
||||
break;
|
||||
case TypeIndex::UInt16:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(UInt16), false);
|
||||
break;
|
||||
case TypeIndex::UInt32:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(UInt32), false);
|
||||
break;
|
||||
case TypeIndex::UInt64:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeInt>(NumpyDataType::Endianness::LITTLE, sizeof(UInt64), false);
|
||||
break;
|
||||
case TypeIndex::Float32:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeFloat>(NumpyDataType::Endianness::LITTLE, sizeof(Float32));
|
||||
break;
|
||||
case TypeIndex::Float64:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeFloat>(NumpyDataType::Endianness::LITTLE, sizeof(Float64));
|
||||
break;
|
||||
case TypeIndex::FixedString:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeString>(
|
||||
NumpyDataType::Endianness::NONE, assert_cast<const DataTypeFixedString *>(type.get())->getN());
|
||||
break;
|
||||
case TypeIndex::String:
|
||||
numpy_data_type = std::make_shared<NumpyDataTypeString>(NumpyDataType::Endianness::NONE, 0);
|
||||
break;
|
||||
case TypeIndex::Array:
|
||||
return getNumpyDataType(assert_cast<const DataTypeArray *>(type.get())->getNestedType());
|
||||
default:
|
||||
nested_data_type = type;
|
||||
return false;
|
||||
}
|
||||
|
||||
nested_data_type = type;
|
||||
@ -117,6 +135,9 @@ void NpyOutputFormat::consume(Chunk chunk)
|
||||
initShape(column);
|
||||
is_initialized = true;
|
||||
}
|
||||
// ColumnPtr checkShape, if nullptr?
|
||||
// updateSizeIfTypeString
|
||||
// columns.push_back()
|
||||
|
||||
if (!checkShape(column))
|
||||
{
|
||||
@ -130,13 +151,9 @@ void NpyOutputFormat::initShape(const ColumnPtr & column)
|
||||
{
|
||||
auto type = data_type;
|
||||
ColumnPtr nested_column = column;
|
||||
while (type->getTypeId() == TypeIndex::Array)
|
||||
while (const auto * array_column = typeid_cast<const ColumnArray *>(nested_column.get()))
|
||||
{
|
||||
const auto * array_column = assert_cast<const ColumnArray *>(nested_column.get());
|
||||
|
||||
numpy_shape.push_back(array_column->getOffsets()[0]);
|
||||
|
||||
type = assert_cast<const DataTypeArray *>(type.get())->getNestedType();
|
||||
nested_column = array_column->getDataPtr();
|
||||
}
|
||||
}
|
||||
@ -166,7 +183,8 @@ bool NpyOutputFormat::checkShape(const ColumnPtr & column)
|
||||
for (size_t i = 0; i < string_offsets.size(); ++i)
|
||||
{
|
||||
size_t string_length = static_cast<size_t>(string_offsets[i] - 1 - string_offsets[i - 1]);
|
||||
numpy_data_type.size = numpy_data_type.size > string_length ? numpy_data_type.size : string_length;
|
||||
if (numpy_data_type->getSize() < string_length)
|
||||
numpy_data_type->setSize(string_length);
|
||||
}
|
||||
}
|
||||
|
||||
@ -185,7 +203,7 @@ void NpyOutputFormat::finalizeImpl()
|
||||
|
||||
void NpyOutputFormat::writeHeader()
|
||||
{
|
||||
String dict = "{'descr':'" + numpy_data_type.str() + "','fortran_order':False,'shape':(" + shapeStr() + "),}";
|
||||
String dict = "{'descr':'" + numpy_data_type->str() + "','fortran_order':False,'shape':(" + shapeStr() + "),}";
|
||||
String padding = "\n";
|
||||
|
||||
/// completes the length of the header, which is divisible by 64.
|
||||
@ -221,9 +239,14 @@ void NpyOutputFormat::writeColumns()
|
||||
case TypeIndex::UInt64: writeNumpyNumbers<ColumnUInt64, UInt64>(column, out); break;
|
||||
case TypeIndex::Float32: writeNumpyNumbers<ColumnFloat32, Float32>(column, out); break;
|
||||
case TypeIndex::Float64: writeNumpyNumbers<ColumnFloat64, Float64>(column, out); break;
|
||||
case TypeIndex::FixedString: writeNumpyStrings<ColumnFixedString>(column, numpy_data_type.size, out); break;
|
||||
case TypeIndex::String: writeNumpyStrings<ColumnString>(column, numpy_data_type.size, out); break;
|
||||
default: break;
|
||||
case TypeIndex::FixedString:
|
||||
writeNumpyStrings<ColumnFixedString>(column, numpy_data_type->getSize(), out);
|
||||
break;
|
||||
case TypeIndex::String:
|
||||
writeNumpyStrings<ColumnString>(column, numpy_data_type->getSize(), out);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include <IO/WriteBufferFromVector.h>
|
||||
#include <Processors/Formats/IRowOutputFormat.h>
|
||||
#include <Formats/FormatSettings.h>
|
||||
#include <Formats/NumpyDataTypes.h>
|
||||
#include <Columns/IColumn.h>
|
||||
#include <Common/PODArray_fwd.h>
|
||||
|
||||
@ -28,18 +29,6 @@ public:
|
||||
String getContentType() const override { return "application/octet-stream"; }
|
||||
|
||||
private:
|
||||
struct NumpyDataType
|
||||
{
|
||||
char endianness;
|
||||
char type;
|
||||
size_t size;
|
||||
|
||||
NumpyDataType() = default;
|
||||
NumpyDataType(char endianness_, char type_, size_t size_)
|
||||
: endianness(endianness_), type(type_), size(size_) {}
|
||||
String str() const;
|
||||
};
|
||||
|
||||
String shapeStr() const;
|
||||
|
||||
bool getNumpyDataType(const DataTypePtr & type);
|
||||
@ -57,7 +46,7 @@ private:
|
||||
|
||||
DataTypePtr data_type;
|
||||
DataTypePtr nested_data_type;
|
||||
NumpyDataType numpy_data_type;
|
||||
std::shared_ptr<NumpyDataType> numpy_data_type;
|
||||
UInt64 num_rows = 0;
|
||||
std::vector<UInt64> numpy_shape;
|
||||
Columns columns;
|
||||
|
Loading…
Reference in New Issue
Block a user