diff --git a/docker/test/stateless/Dockerfile b/docker/test/stateless/Dockerfile index b063f8d81f6..10b213803c9 100644 --- a/docker/test/stateless/Dockerfile +++ b/docker/test/stateless/Dockerfile @@ -13,6 +13,7 @@ RUN apt-get update -y \ ncdu \ netcat-openbsd \ openssl \ + protobuf-compiler \ python3 \ python3-lxml \ python3-requests \ diff --git a/src/Columns/ColumnFixedString.cpp b/src/Columns/ColumnFixedString.cpp index 55e387ff2ee..6cfec89a5dc 100644 --- a/src/Columns/ColumnFixedString.cpp +++ b/src/Columns/ColumnFixedString.cpp @@ -446,4 +446,18 @@ void ColumnFixedString::getExtremes(Field & min, Field & max) const get(max_idx, max); } +void ColumnFixedString::alignStringLength(ColumnFixedString::Chars & data, size_t n, size_t old_size) +{ + size_t length = data.size() - old_size; + if (length < n) + { + data.resize_fill(old_size + n); + } + else if (length > n) + { + data.resize_assume_reserved(old_size); + throw Exception("Too large value for FixedString(" + std::to_string(n) + ")", ErrorCodes::TOO_LARGE_STRING_SIZE); + } +} + } diff --git a/src/Columns/ColumnFixedString.h b/src/Columns/ColumnFixedString.h index 286b3a752dc..24a99c27b13 100644 --- a/src/Columns/ColumnFixedString.h +++ b/src/Columns/ColumnFixedString.h @@ -182,7 +182,8 @@ public: const Chars & getChars() const { return chars; } size_t getN() const { return n; } + + static void alignStringLength(ColumnFixedString::Chars & data, size_t n, size_t old_size); }; - } diff --git a/src/Common/ErrorCodes.cpp b/src/Common/ErrorCodes.cpp index d0d83448b68..52c22c2e371 100644 --- a/src/Common/ErrorCodes.cpp +++ b/src/Common/ErrorCodes.cpp @@ -404,7 +404,7 @@ M(432, UNKNOWN_CODEC) \ M(433, ILLEGAL_CODEC_PARAMETER) \ M(434, CANNOT_PARSE_PROTOBUF_SCHEMA) \ - M(435, NO_DATA_FOR_REQUIRED_PROTOBUF_FIELD) \ + M(435, NO_COLUMN_SERIALIZED_TO_REQUIRED_PROTOBUF_FIELD) \ M(436, PROTOBUF_BAD_CAST) \ M(437, PROTOBUF_FIELD_NOT_REPEATED) \ M(438, DATA_TYPE_CANNOT_BE_PROMOTED) \ @@ -412,7 +412,7 @@ M(440, INVALID_LIMIT_EXPRESSION) \ M(441, CANNOT_PARSE_DOMAIN_VALUE_FROM_STRING) \ M(442, BAD_DATABASE_FOR_TEMPORARY_TABLE) \ - M(443, NO_COMMON_COLUMNS_WITH_PROTOBUF_SCHEMA) \ + M(443, NO_COLUMNS_SERIALIZED_TO_PROTOBUF_FIELDS) \ M(444, UNKNOWN_PROTOBUF_FORMAT) \ M(445, CANNOT_MPROTECT) \ M(446, FUNCTION_NOT_ALLOWED) \ @@ -535,6 +535,8 @@ M(566, CANNOT_RMDIR) \ M(567, DUPLICATED_PART_UUIDS) \ M(568, RAFT_ERROR) \ + M(569, MULTIPLE_COLUMNS_SERIALIZED_TO_SAME_PROTOBUF_FIELD) \ + M(570, DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD) \ \ M(999, KEEPER_EXCEPTION) \ M(1000, POCO_EXCEPTION) \ diff --git a/src/DataTypes/DataTypeAggregateFunction.cpp b/src/DataTypes/DataTypeAggregateFunction.cpp index 9104c12120f..e92994ae979 100644 --- a/src/DataTypes/DataTypeAggregateFunction.cpp +++ b/src/DataTypes/DataTypeAggregateFunction.cpp @@ -10,8 +10,6 @@ #include #include -#include -#include #include #include #include @@ -261,45 +259,6 @@ void DataTypeAggregateFunction::deserializeTextCSV(IColumn & column, ReadBuffer } -void DataTypeAggregateFunction::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - value_index = static_cast( - protobuf.writeAggregateFunction(function, assert_cast(column).getData()[row_num])); -} - -void DataTypeAggregateFunction::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - ColumnAggregateFunction & column_concrete = assert_cast(column); - Arena & arena = column_concrete.createOrGetArena(); - size_t size_of_state = function->sizeOfData(); - AggregateDataPtr place = arena.alignedAlloc(size_of_state, function->alignOfData()); - function->create(place); - try - { - if (!protobuf.readAggregateFunction(function, place, arena)) - { - function->destroy(place); - return; - } - auto & container = column_concrete.getData(); - if (allow_add_row) - { - container.emplace_back(place); - row_added = true; - } - else - container.back() = place; - } - catch (...) - { - function->destroy(place); - throw; - } -} - MutableColumnPtr DataTypeAggregateFunction::createColumn() const { return ColumnAggregateFunction::create(function); diff --git a/src/DataTypes/DataTypeAggregateFunction.h b/src/DataTypes/DataTypeAggregateFunction.h index 9ae7c67a803..d07d46fd3ee 100644 --- a/src/DataTypes/DataTypeAggregateFunction.h +++ b/src/DataTypes/DataTypeAggregateFunction.h @@ -59,8 +59,6 @@ public: void serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; MutableColumnPtr createColumn() const override; diff --git a/src/DataTypes/DataTypeArray.cpp b/src/DataTypes/DataTypeArray.cpp index 3ad84a8fcd7..27088ab822c 100644 --- a/src/DataTypes/DataTypeArray.cpp +++ b/src/DataTypes/DataTypeArray.cpp @@ -6,7 +6,6 @@ #include #include -#include #include #include #include @@ -522,55 +521,6 @@ void DataTypeArray::deserializeTextCSV(IColumn & column, ReadBuffer & istr, cons } -void DataTypeArray::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - const ColumnArray & column_array = assert_cast(column); - const ColumnArray::Offsets & offsets = column_array.getOffsets(); - size_t offset = offsets[row_num - 1] + value_index; - size_t next_offset = offsets[row_num]; - const IColumn & nested_column = column_array.getData(); - size_t i; - for (i = offset; i < next_offset; ++i) - { - size_t element_stored = 0; - nested->serializeProtobuf(nested_column, i, protobuf, element_stored); - if (!element_stored) - break; - } - value_index += i - offset; -} - - -void DataTypeArray::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - ColumnArray & column_array = assert_cast(column); - IColumn & nested_column = column_array.getData(); - ColumnArray::Offsets & offsets = column_array.getOffsets(); - size_t old_size = offsets.size(); - try - { - bool nested_row_added; - do - nested->deserializeProtobuf(nested_column, protobuf, true, nested_row_added); - while (nested_row_added && protobuf.canReadMoreValues()); - if (allow_add_row) - { - offsets.emplace_back(nested_column.size()); - row_added = true; - } - else - offsets.back() = nested_column.size(); - } - catch (...) - { - offsets.resize_assume_reserved(old_size); - nested_column.popBack(nested_column.size() - offsets.back()); - throw; - } -} - - MutableColumnPtr DataTypeArray::createColumn() const { return ColumnArray::create(nested->createColumn(), ColumnArray::ColumnOffsets::create()); diff --git a/src/DataTypes/DataTypeArray.h b/src/DataTypes/DataTypeArray.h index ba19ad021be..4185163e2e7 100644 --- a/src/DataTypes/DataTypeArray.h +++ b/src/DataTypes/DataTypeArray.h @@ -85,15 +85,6 @@ public: DeserializeBinaryBulkStatePtr & state, SubstreamsCache * cache) const override; - void serializeProtobuf(const IColumn & column, - size_t row_num, - ProtobufWriter & protobuf, - size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, - ProtobufReader & protobuf, - bool allow_add_row, - bool & row_added) const override; - MutableColumnPtr createColumn() const override; Field getDefault() const override; diff --git a/src/DataTypes/DataTypeDate.cpp b/src/DataTypes/DataTypeDate.cpp index 2c1dfcbb0fe..192a89cc454 100644 --- a/src/DataTypes/DataTypeDate.cpp +++ b/src/DataTypes/DataTypeDate.cpp @@ -4,8 +4,6 @@ #include #include #include -#include -#include #include @@ -81,30 +79,6 @@ void DataTypeDate::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const assert_cast(column).getData().push_back(value.getDayNum()); } -void DataTypeDate::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - value_index = static_cast(protobuf.writeDate(DayNum(assert_cast(column).getData()[row_num]))); -} - -void DataTypeDate::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - DayNum d; - if (!protobuf.readDate(d)) - return; - - auto & container = assert_cast(column).getData(); - if (allow_add_row) - { - container.emplace_back(d); - row_added = true; - } - else - container.back() = d; -} - bool DataTypeDate::equals(const IDataType & rhs) const { return typeid(rhs) == typeid(*this); diff --git a/src/DataTypes/DataTypeDate.h b/src/DataTypes/DataTypeDate.h index 00afba424e4..496d7fe0b22 100644 --- a/src/DataTypes/DataTypeDate.h +++ b/src/DataTypes/DataTypeDate.h @@ -24,8 +24,6 @@ public: void deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; void serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; bool canBeUsedAsVersion() const override { return true; } bool canBeInsideNullable() const override { return true; } diff --git a/src/DataTypes/DataTypeDateTime.cpp b/src/DataTypes/DataTypeDateTime.cpp index bfb4473e429..d2bbb4a1efa 100644 --- a/src/DataTypes/DataTypeDateTime.cpp +++ b/src/DataTypes/DataTypeDateTime.cpp @@ -5,8 +5,6 @@ #include #include #include -#include -#include #include #include #include @@ -164,32 +162,6 @@ void DataTypeDateTime::deserializeTextCSV(IColumn & column, ReadBuffer & istr, c assert_cast(column).getData().push_back(x); } -void DataTypeDateTime::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - - // On some platforms `time_t` is `long` but not `unsigned int` (UInt32 that we store in column), hence static_cast. - value_index = static_cast(protobuf.writeDateTime(static_cast(assert_cast(column).getData()[row_num]))); -} - -void DataTypeDateTime::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - time_t t; - if (!protobuf.readDateTime(t)) - return; - - auto & container = assert_cast(column).getData(); - if (allow_add_row) - { - container.emplace_back(t); - row_added = true; - } - else - container.back() = t; -} - bool DataTypeDateTime::equals(const IDataType & rhs) const { /// DateTime with different timezones are equal, because: diff --git a/src/DataTypes/DataTypeDateTime.h b/src/DataTypes/DataTypeDateTime.h index 47c7f361091..edec889309b 100644 --- a/src/DataTypes/DataTypeDateTime.h +++ b/src/DataTypes/DataTypeDateTime.h @@ -68,8 +68,6 @@ public: void deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; void serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; bool canBeUsedAsVersion() const override { return true; } bool canBeInsideNullable() const override { return true; } diff --git a/src/DataTypes/DataTypeDateTime64.cpp b/src/DataTypes/DataTypeDateTime64.cpp index ef1a971510a..09e39c2de1a 100644 --- a/src/DataTypes/DataTypeDateTime64.cpp +++ b/src/DataTypes/DataTypeDateTime64.cpp @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include #include #include @@ -182,30 +180,6 @@ void DataTypeDateTime64::deserializeTextCSV(IColumn & column, ReadBuffer & istr, assert_cast(column).getData().push_back(x); } -void DataTypeDateTime64::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - value_index = static_cast(protobuf.writeDateTime64(assert_cast(column).getData()[row_num], scale)); -} - -void DataTypeDateTime64::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - DateTime64 t = 0; - if (!protobuf.readDateTime64(t, scale)) - return; - - auto & container = assert_cast(column).getData(); - if (allow_add_row) - { - container.emplace_back(t); - row_added = true; - } - else - container.back() = t; -} - bool DataTypeDateTime64::equals(const IDataType & rhs) const { if (const auto * ptype = typeid_cast(&rhs)) diff --git a/src/DataTypes/DataTypeDateTime64.h b/src/DataTypes/DataTypeDateTime64.h index 003e83b7195..198c3739f58 100644 --- a/src/DataTypes/DataTypeDateTime64.h +++ b/src/DataTypes/DataTypeDateTime64.h @@ -42,8 +42,6 @@ public: void deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; void serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; bool equals(const IDataType & rhs) const override; diff --git a/src/DataTypes/DataTypeDecimalBase.cpp b/src/DataTypes/DataTypeDecimalBase.cpp index 9fb445ab00d..ab17996167c 100644 --- a/src/DataTypes/DataTypeDecimalBase.cpp +++ b/src/DataTypes/DataTypeDecimalBase.cpp @@ -4,8 +4,6 @@ #include #include #include -#include -#include #include #include #include diff --git a/src/DataTypes/DataTypeEnum.cpp b/src/DataTypes/DataTypeEnum.cpp index 650a1da6407..043c971266c 100644 --- a/src/DataTypes/DataTypeEnum.cpp +++ b/src/DataTypes/DataTypeEnum.cpp @@ -1,7 +1,5 @@ #include #include -#include -#include #include #include #include @@ -254,34 +252,6 @@ void DataTypeEnum::deserializeBinaryBulk( x.resize(initial_size + size / sizeof(FieldType)); } -template -void DataTypeEnum::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - protobuf.prepareEnumMapping(values); - value_index = static_cast(protobuf.writeEnum(assert_cast(column).getData()[row_num])); -} - -template -void DataTypeEnum::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - protobuf.prepareEnumMapping(values); - row_added = false; - Type value; - if (!protobuf.readEnum(value)) - return; - - auto & container = assert_cast(column).getData(); - if (allow_add_row) - { - container.emplace_back(value); - row_added = true; - } - else - container.back() = value; -} - template Field DataTypeEnum::getDefault() const { diff --git a/src/DataTypes/DataTypeEnum.h b/src/DataTypes/DataTypeEnum.h index c75d348f15c..003613edb98 100644 --- a/src/DataTypes/DataTypeEnum.h +++ b/src/DataTypes/DataTypeEnum.h @@ -132,9 +132,6 @@ public: void serializeBinaryBulk(const IColumn & column, WriteBuffer & ostr, const size_t offset, size_t limit) const override; void deserializeBinaryBulk(IColumn & column, ReadBuffer & istr, const size_t limit, const double avg_value_size_hint) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; - MutableColumnPtr createColumn() const override { return ColumnType::create(); } Field getDefault() const override; diff --git a/src/DataTypes/DataTypeFixedString.cpp b/src/DataTypes/DataTypeFixedString.cpp index 585c5709be7..21cfe855169 100644 --- a/src/DataTypes/DataTypeFixedString.cpp +++ b/src/DataTypes/DataTypeFixedString.cpp @@ -2,8 +2,6 @@ #include #include -#include -#include #include #include @@ -25,7 +23,6 @@ namespace DB namespace ErrorCodes { extern const int CANNOT_READ_ALL_DATA; - extern const int TOO_LARGE_STRING_SIZE; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int UNEXPECTED_AST_STRUCTURE; } @@ -127,16 +124,7 @@ static inline void alignStringLength(const DataTypeFixedString & type, ColumnFixedString::Chars & data, size_t string_start) { - size_t length = data.size() - string_start; - if (length < type.getN()) - { - data.resize_fill(string_start + type.getN()); - } - else if (length > type.getN()) - { - data.resize_assume_reserved(string_start); - throw Exception("Too large value for " + type.getName(), ErrorCodes::TOO_LARGE_STRING_SIZE); - } + ColumnFixedString::alignStringLength(data, type.getN(), string_start); } template @@ -215,53 +203,6 @@ void DataTypeFixedString::deserializeTextCSV(IColumn & column, ReadBuffer & istr } -void DataTypeFixedString::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - const char * pos = reinterpret_cast(&assert_cast(column).getChars()[n * row_num]); - value_index = static_cast(protobuf.writeString(StringRef(pos, n))); -} - - -void DataTypeFixedString::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - auto & column_string = assert_cast(column); - ColumnFixedString::Chars & data = column_string.getChars(); - size_t old_size = data.size(); - try - { - if (allow_add_row) - { - if (protobuf.readStringInto(data)) - { - alignStringLength(*this, data, old_size); - row_added = true; - } - else - data.resize_assume_reserved(old_size); - } - else - { - ColumnFixedString::Chars temp_data; - if (protobuf.readStringInto(temp_data)) - { - alignStringLength(*this, temp_data, 0); - column_string.popBack(1); - old_size = data.size(); - data.insertSmallAllowReadWriteOverflow15(temp_data.begin(), temp_data.end()); - } - } - } - catch (...) - { - data.resize_assume_reserved(old_size); - throw; - } -} - - MutableColumnPtr DataTypeFixedString::createColumn() const { return ColumnFixedString::create(n); diff --git a/src/DataTypes/DataTypeFixedString.h b/src/DataTypes/DataTypeFixedString.h index e410d1b0596..af82e4b5d11 100644 --- a/src/DataTypes/DataTypeFixedString.h +++ b/src/DataTypes/DataTypeFixedString.h @@ -66,9 +66,6 @@ public: void serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; - MutableColumnPtr createColumn() const override; Field getDefault() const override; diff --git a/src/DataTypes/DataTypeLowCardinality.cpp b/src/DataTypes/DataTypeLowCardinality.cpp index 9614c150c7d..1b21b7de4bc 100644 --- a/src/DataTypes/DataTypeLowCardinality.cpp +++ b/src/DataTypes/DataTypeLowCardinality.cpp @@ -808,31 +808,6 @@ void DataTypeLowCardinality::serializeTextXML(const IColumn & column, size_t row serializeImpl(column, row_num, &IDataType::serializeAsTextXML, ostr, settings); } -void DataTypeLowCardinality::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - serializeImpl(column, row_num, &IDataType::serializeProtobuf, protobuf, value_index); -} - -void DataTypeLowCardinality::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - if (allow_add_row) - { - deserializeImpl(column, &IDataType::deserializeProtobuf, protobuf, true, row_added); - return; - } - - row_added = false; - auto & low_cardinality_column= getColumnLowCardinality(column); - auto nested_column = low_cardinality_column.getDictionary().getNestedColumn(); - auto temp_column = nested_column->cloneEmpty(); - size_t unique_row_number = low_cardinality_column.getIndexes().getUInt(low_cardinality_column.size() - 1); - temp_column->insertFrom(*nested_column, unique_row_number); - bool dummy; - dictionary_type.get()->deserializeProtobuf(*temp_column, protobuf, false, dummy); - low_cardinality_column.popBack(1); - low_cardinality_column.insertFromFullColumn(*temp_column, 0); -} - template void DataTypeLowCardinality::serializeImpl( const IColumn & column, size_t row_num, DataTypeLowCardinality::SerializeFunctionPtr func, Args &&... args) const diff --git a/src/DataTypes/DataTypeLowCardinality.h b/src/DataTypes/DataTypeLowCardinality.h index 6ed2b792ce3..14beb423f1f 100644 --- a/src/DataTypes/DataTypeLowCardinality.h +++ b/src/DataTypes/DataTypeLowCardinality.h @@ -65,8 +65,6 @@ public: void serializeTextJSON(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const override; void deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override; void serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings & settings) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; MutableColumnPtr createColumn() const override; diff --git a/src/DataTypes/DataTypeMap.cpp b/src/DataTypes/DataTypeMap.cpp index af2ed8805e8..9972452862f 100644 --- a/src/DataTypes/DataTypeMap.cpp +++ b/src/DataTypes/DataTypeMap.cpp @@ -336,16 +336,6 @@ void DataTypeMap::deserializeBinaryBulkWithMultipleStreamsImpl( nested->deserializeBinaryBulkWithMultipleStreams(column_map.getNestedColumnPtr(), limit, settings, state, cache); } -void DataTypeMap::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - nested->serializeProtobuf(extractNestedColumn(column), row_num, protobuf, value_index); -} - -void DataTypeMap::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - nested->deserializeProtobuf(extractNestedColumn(column), protobuf, allow_add_row, row_added); -} - MutableColumnPtr DataTypeMap::createColumn() const { return ColumnMap::create(nested->createColumn()); diff --git a/src/DataTypes/DataTypeMap.h b/src/DataTypes/DataTypeMap.h index ea495f05548..88ea44a0d5a 100644 --- a/src/DataTypes/DataTypeMap.h +++ b/src/DataTypes/DataTypeMap.h @@ -76,9 +76,6 @@ public: DeserializeBinaryBulkStatePtr & state, SubstreamsCache * cache) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; - MutableColumnPtr createColumn() const override; Field getDefault() const override; @@ -92,6 +89,8 @@ public: const DataTypePtr & getValueType() const { return value_type; } DataTypes getKeyValueTypes() const { return {key_type, value_type}; } + const DataTypePtr & getNestedType() const { return nested; } + private: template void serializeTextImpl(const IColumn & column, size_t row_num, WriteBuffer & ostr, Writer && writer) const; diff --git a/src/DataTypes/DataTypeNullable.cpp b/src/DataTypes/DataTypeNullable.cpp index c3b734686f8..903ebeb3ddc 100644 --- a/src/DataTypes/DataTypeNullable.cpp +++ b/src/DataTypes/DataTypeNullable.cpp @@ -486,33 +486,6 @@ void DataTypeNullable::serializeTextXML(const IColumn & column, size_t row_num, nested_data_type->serializeAsTextXML(col.getNestedColumn(), row_num, ostr, settings); } -void DataTypeNullable::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - const ColumnNullable & col = assert_cast(column); - if (!col.isNullAt(row_num)) - nested_data_type->serializeProtobuf(col.getNestedColumn(), row_num, protobuf, value_index); -} - -void DataTypeNullable::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - ColumnNullable & col = assert_cast(column); - IColumn & nested_column = col.getNestedColumn(); - size_t old_size = nested_column.size(); - try - { - nested_data_type->deserializeProtobuf(nested_column, protobuf, allow_add_row, row_added); - if (row_added) - col.getNullMapData().push_back(0); - } - catch (...) - { - nested_column.popBack(nested_column.size() - old_size); - col.getNullMapData().resize_assume_reserved(old_size); - row_added = false; - throw; - } -} - MutableColumnPtr DataTypeNullable::createColumn() const { return ColumnNullable::create(nested_data_type->createColumn(), ColumnUInt8::create()); diff --git a/src/DataTypes/DataTypeNullable.h b/src/DataTypes/DataTypeNullable.h index db641faf0af..5e71a1bee4d 100644 --- a/src/DataTypes/DataTypeNullable.h +++ b/src/DataTypes/DataTypeNullable.h @@ -73,9 +73,6 @@ public: void serializeText(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void serializeTextXML(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; - MutableColumnPtr createColumn() const override; Field getDefault() const override; diff --git a/src/DataTypes/DataTypeNumberBase.cpp b/src/DataTypes/DataTypeNumberBase.cpp index a9b9bbc8090..ae3e6762d27 100644 --- a/src/DataTypes/DataTypeNumberBase.cpp +++ b/src/DataTypes/DataTypeNumberBase.cpp @@ -8,8 +8,6 @@ #include #include #include -#include -#include namespace DB @@ -205,34 +203,6 @@ void DataTypeNumberBase::deserializeBinaryBulk(IColumn & column, ReadBuffer & } -template -void DataTypeNumberBase::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - value_index = static_cast(protobuf.writeNumber(assert_cast &>(column).getData()[row_num])); -} - - -template -void DataTypeNumberBase::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - T value; - if (!protobuf.readNumber(value)) - return; - - auto & container = typeid_cast &>(column).getData(); - if (allow_add_row) - { - container.emplace_back(value); - row_added = true; - } - else - container.back() = value; -} - - template MutableColumnPtr DataTypeNumberBase::createColumn() const { diff --git a/src/DataTypes/DataTypeNumberBase.h b/src/DataTypes/DataTypeNumberBase.h index 1491eabfbd5..22a70ac7277 100644 --- a/src/DataTypes/DataTypeNumberBase.h +++ b/src/DataTypes/DataTypeNumberBase.h @@ -45,9 +45,6 @@ public: void serializeBinaryBulk(const IColumn & column, WriteBuffer & ostr, size_t offset, size_t limit) const override; void deserializeBinaryBulk(IColumn & column, ReadBuffer & istr, size_t limit, double avg_value_size_hint) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; - MutableColumnPtr createColumn() const override; bool isParametric() const override { return false; } diff --git a/src/DataTypes/DataTypeString.cpp b/src/DataTypes/DataTypeString.cpp index c752d136642..d760df5075d 100644 --- a/src/DataTypes/DataTypeString.cpp +++ b/src/DataTypes/DataTypeString.cpp @@ -9,8 +9,6 @@ #include #include -#include -#include #include #include @@ -311,55 +309,6 @@ void DataTypeString::deserializeTextCSV(IColumn & column, ReadBuffer & istr, con } -void DataTypeString::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - value_index = static_cast(protobuf.writeString(assert_cast(column).getDataAt(row_num))); -} - - -void DataTypeString::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - auto & column_string = assert_cast(column); - ColumnString::Chars & data = column_string.getChars(); - ColumnString::Offsets & offsets = column_string.getOffsets(); - size_t old_size = offsets.size(); - try - { - if (allow_add_row) - { - if (protobuf.readStringInto(data)) - { - data.emplace_back(0); - offsets.emplace_back(data.size()); - row_added = true; - } - else - data.resize_assume_reserved(offsets.back()); - } - else - { - ColumnString::Chars temp_data; - if (protobuf.readStringInto(temp_data)) - { - temp_data.emplace_back(0); - column_string.popBack(1); - old_size = offsets.size(); - data.insertSmallAllowReadWriteOverflow15(temp_data.begin(), temp_data.end()); - offsets.emplace_back(data.size()); - } - } - } - catch (...) - { - offsets.resize_assume_reserved(old_size); - data.resize_assume_reserved(offsets.back()); - throw; - } -} - Field DataTypeString::getDefault() const { return String(); diff --git a/src/DataTypes/DataTypeString.h b/src/DataTypes/DataTypeString.h index f6db8fe73d4..7f8aa1fd0cf 100644 --- a/src/DataTypes/DataTypeString.h +++ b/src/DataTypes/DataTypeString.h @@ -47,9 +47,6 @@ public: void serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; - MutableColumnPtr createColumn() const override; Field getDefault() const override; diff --git a/src/DataTypes/DataTypeTuple.cpp b/src/DataTypes/DataTypeTuple.cpp index c62aa1c1187..2261e776ea2 100644 --- a/src/DataTypes/DataTypeTuple.cpp +++ b/src/DataTypes/DataTypeTuple.cpp @@ -504,33 +504,6 @@ void DataTypeTuple::deserializeBinaryBulkWithMultipleStreamsImpl( settings.path.pop_back(); } -void DataTypeTuple::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - for (; value_index < elems.size(); ++value_index) - { - size_t stored = 0; - elems[value_index]->serializeProtobuf(extractElementColumn(column, value_index), row_num, protobuf, stored); - if (!stored) - break; - } -} - -void DataTypeTuple::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - bool all_elements_get_row = true; - addElementSafe(elems, column, [&] - { - for (const auto & i : ext::range(0, ext::size(elems))) - { - bool element_row_added; - elems[i]->deserializeProtobuf(extractElementColumn(column, i), protobuf, allow_add_row, element_row_added); - all_elements_get_row &= element_row_added; - } - }); - row_added = all_elements_get_row; -} - MutableColumnPtr DataTypeTuple::createColumn() const { size_t size = elems.size(); diff --git a/src/DataTypes/DataTypeTuple.h b/src/DataTypes/DataTypeTuple.h index 0b28ebe5a63..12ccf574c0e 100644 --- a/src/DataTypes/DataTypeTuple.h +++ b/src/DataTypes/DataTypeTuple.h @@ -81,9 +81,6 @@ public: DeserializeBinaryBulkStatePtr & state, SubstreamsCache * cache) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; - MutableColumnPtr createColumn() const override; Field getDefault() const override; diff --git a/src/DataTypes/DataTypeUUID.cpp b/src/DataTypes/DataTypeUUID.cpp index 94a043eb472..b66cbadaef0 100644 --- a/src/DataTypes/DataTypeUUID.cpp +++ b/src/DataTypes/DataTypeUUID.cpp @@ -1,8 +1,6 @@ #include #include #include -#include -#include #include #include #include @@ -79,30 +77,6 @@ void DataTypeUUID::deserializeTextCSV(IColumn & column, ReadBuffer & istr, const assert_cast(column).getData().push_back(value); } -void DataTypeUUID::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - value_index = static_cast(protobuf.writeUUID(UUID(assert_cast(column).getData()[row_num]))); -} - -void DataTypeUUID::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - UUID uuid; - if (!protobuf.readUUID(uuid)) - return; - - auto & container = assert_cast(column).getData(); - if (allow_add_row) - { - container.emplace_back(uuid); - row_added = true; - } - else - container.back() = uuid; -} - bool DataTypeUUID::equals(const IDataType & rhs) const { return typeid(rhs) == typeid(*this); diff --git a/src/DataTypes/DataTypeUUID.h b/src/DataTypes/DataTypeUUID.h index 6290d05cc3b..de0c7c7d8cf 100644 --- a/src/DataTypes/DataTypeUUID.h +++ b/src/DataTypes/DataTypeUUID.h @@ -26,8 +26,6 @@ public: void deserializeTextJSON(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; void serializeTextCSV(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings & settings) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; bool canBeUsedInBitOperations() const override { return true; } bool canBeInsideNullable() const override { return true; } diff --git a/src/DataTypes/DataTypesDecimal.cpp b/src/DataTypes/DataTypesDecimal.cpp index 6c325c5d371..e174a242462 100644 --- a/src/DataTypes/DataTypesDecimal.cpp +++ b/src/DataTypes/DataTypesDecimal.cpp @@ -4,8 +4,6 @@ #include #include #include -#include -#include #include #include #include @@ -111,33 +109,6 @@ T DataTypeDecimal::parseFromString(const String & str) const return x; } -template -void DataTypeDecimal::serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const -{ - if (value_index) - return; - value_index = static_cast(protobuf.writeDecimal(assert_cast(column).getData()[row_num], this->scale)); -} - - -template -void DataTypeDecimal::deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const -{ - row_added = false; - T decimal; - if (!protobuf.readDecimal(decimal, this->precision, this->scale)) - return; - - auto & container = assert_cast(column).getData(); - if (allow_add_row) - { - container.emplace_back(decimal); - row_added = true; - } - else - container.back() = decimal; -} - static DataTypePtr create(const ASTPtr & arguments) { diff --git a/src/DataTypes/DataTypesDecimal.h b/src/DataTypes/DataTypesDecimal.h index 3f7b4e2ac63..08f44c60c41 100644 --- a/src/DataTypes/DataTypesDecimal.h +++ b/src/DataTypes/DataTypesDecimal.h @@ -46,9 +46,6 @@ public: void deserializeText(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; void deserializeTextCSV(IColumn & column, ReadBuffer & istr, const FormatSettings &) const override; - void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const override; - void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const override; - bool equals(const IDataType & rhs) const override; T parseFromString(const String & str) const; diff --git a/src/DataTypes/IDataType.h b/src/DataTypes/IDataType.h index dba5bc3f5a9..c9c848a8037 100644 --- a/src/DataTypes/IDataType.h +++ b/src/DataTypes/IDataType.h @@ -26,9 +26,6 @@ class Field; using DataTypePtr = std::shared_ptr; using DataTypes = std::vector; -class ProtobufReader; -class ProtobufWriter; - struct NameAndTypePair; @@ -235,10 +232,6 @@ public: /// If method will throw an exception, then column will be in same state as before call to method. virtual void deserializeBinary(IColumn & column, ReadBuffer & istr) const = 0; - /** Serialize to a protobuf. */ - virtual void serializeProtobuf(const IColumn & column, size_t row_num, ProtobufWriter & protobuf, size_t & value_index) const = 0; - virtual void deserializeProtobuf(IColumn & column, ProtobufReader & protobuf, bool allow_add_row, bool & row_added) const = 0; - /** Text serialization with escaping but without quoting. */ void serializeAsTextEscaped(const IColumn & column, size_t row_num, WriteBuffer & ostr, const FormatSettings &) const; diff --git a/src/DataTypes/IDataTypeDummy.h b/src/DataTypes/IDataTypeDummy.h index f27359e5f74..08cc0778a6e 100644 --- a/src/DataTypes/IDataTypeDummy.h +++ b/src/DataTypes/IDataTypeDummy.h @@ -34,8 +34,6 @@ public: void deserializeBinaryBulk(IColumn &, ReadBuffer &, size_t, double) const override { throwNoSerialization(); } void serializeText(const IColumn &, size_t, WriteBuffer &, const FormatSettings &) const override { throwNoSerialization(); } void deserializeText(IColumn &, ReadBuffer &, const FormatSettings &) const override { throwNoSerialization(); } - void serializeProtobuf(const IColumn &, size_t, ProtobufWriter &, size_t &) const override { throwNoSerialization(); } - void deserializeProtobuf(IColumn &, ProtobufReader &, bool, bool &) const override { throwNoSerialization(); } MutableColumnPtr createColumn() const override { diff --git a/src/Formats/FormatSettings.h b/src/Formats/FormatSettings.h index 3f031fa2311..c1f02c65748 100644 --- a/src/Formats/FormatSettings.h +++ b/src/Formats/FormatSettings.h @@ -120,7 +120,6 @@ struct FormatSettings struct { - bool write_row_delimiters = true; /** * Some buffers (kafka / rabbit) split the rows internally using callback, * and always send one row per message, so we can push there formats @@ -128,7 +127,7 @@ struct FormatSettings * we have to enforce exporting at most one row in the format output, * because Protobuf without delimiters is not generally useful. */ - bool allow_many_rows_no_delimiters = false; + bool allow_multiple_rows_without_delimiter = false; } protobuf; struct diff --git a/src/Formats/ProtobufColumnMatcher.cpp b/src/Formats/ProtobufColumnMatcher.cpp deleted file mode 100644 index f4803d1af10..00000000000 --- a/src/Formats/ProtobufColumnMatcher.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "ProtobufColumnMatcher.h" -#if USE_PROTOBUF -#include -#include -#include - - -namespace DB -{ -namespace ErrorCodes -{ - extern const int NO_COMMON_COLUMNS_WITH_PROTOBUF_SCHEMA; -} - - -namespace -{ - String columnNameToSearchableForm(const String & str) - { - return Poco::replace(Poco::toUpper(str), ".", "_"); - } -} - -namespace ProtobufColumnMatcher -{ - namespace details - { - ColumnNameMatcher::ColumnNameMatcher(const std::vector & column_names) : column_usage(column_names.size()) - { - column_usage.resize(column_names.size(), false); - for (size_t i = 0; i != column_names.size(); ++i) - column_name_to_index_map.emplace(columnNameToSearchableForm(column_names[i]), i); - } - - size_t ColumnNameMatcher::findColumn(const String & field_name) - { - auto it = column_name_to_index_map.find(columnNameToSearchableForm(field_name)); - if (it == column_name_to_index_map.end()) - return -1; - size_t column_index = it->second; - if (column_usage[column_index]) - return -1; - column_usage[column_index] = true; - return column_index; - } - - void throwNoCommonColumns() - { - throw Exception("No common columns with provided protobuf schema", ErrorCodes::NO_COMMON_COLUMNS_WITH_PROTOBUF_SCHEMA); - } - } -} - -} -#endif diff --git a/src/Formats/ProtobufColumnMatcher.h b/src/Formats/ProtobufColumnMatcher.h deleted file mode 100644 index 35521be7a9b..00000000000 --- a/src/Formats/ProtobufColumnMatcher.h +++ /dev/null @@ -1,196 +0,0 @@ -#pragma once - -#if !defined(ARCADIA_BUILD) -# include "config_formats.h" -#endif - -#if USE_PROTOBUF -# include -# include -# include -# include -# include -# include -# include - -namespace google -{ -namespace protobuf -{ - class Descriptor; - class FieldDescriptor; -} -} - - -namespace DB -{ -namespace ProtobufColumnMatcher -{ - struct DefaultTraits - { - using MessageData = boost::blank; - using FieldData = boost::blank; - }; - - template - struct Message; - - /// Represents a field in a protobuf message. - template - struct Field - { - const google::protobuf::FieldDescriptor * field_descriptor = nullptr; - - /// Same as field_descriptor->number(). - UInt32 field_number = 0; - - /// Index of a column; either 'column_index' or 'nested_message' is set. - size_t column_index = -1; - std::unique_ptr> nested_message; - - typename Traits::FieldData data; - }; - - /// Represents a protobuf message. - template - struct Message - { - std::vector> fields; - - /// Points to the parent message if this is a nested message. - Message * parent = nullptr; - size_t index_in_parent = -1; - - typename Traits::MessageData data; - }; - - /// Utility function finding matching columns for each protobuf field. - template - static std::unique_ptr> matchColumns( - const std::vector & column_names, - const google::protobuf::Descriptor * message_type); - - template - static std::unique_ptr> matchColumns( - const std::vector & column_names, - const google::protobuf::Descriptor * message_type, - std::vector & field_descriptors_without_match); - - namespace details - { - [[noreturn]] void throwNoCommonColumns(); - - class ColumnNameMatcher - { - public: - ColumnNameMatcher(const std::vector & column_names); - size_t findColumn(const String & field_name); - - private: - std::unordered_map column_name_to_index_map; - std::vector column_usage; - }; - - template - std::unique_ptr> matchColumnsRecursive( - ColumnNameMatcher & name_matcher, - const google::protobuf::Descriptor * message_type, - const String & field_name_prefix, - std::vector * field_descriptors_without_match) - { - auto message = std::make_unique>(); - for (int i = 0; i != message_type->field_count(); ++i) - { - const google::protobuf::FieldDescriptor * field_descriptor = message_type->field(i); - if ((field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) - || (field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_GROUP)) - { - auto nested_message = matchColumnsRecursive( - name_matcher, - field_descriptor->message_type(), - field_name_prefix + field_descriptor->name() + ".", - field_descriptors_without_match); - if (nested_message) - { - message->fields.emplace_back(); - auto & current_field = message->fields.back(); - current_field.field_number = field_descriptor->number(); - current_field.field_descriptor = field_descriptor; - current_field.nested_message = std::move(nested_message); - current_field.nested_message->parent = message.get(); - } - } - else - { - size_t column_index = name_matcher.findColumn(field_name_prefix + field_descriptor->name()); - if (column_index == static_cast(-1)) - { - if (field_descriptors_without_match) - field_descriptors_without_match->emplace_back(field_descriptor); - } - else - { - message->fields.emplace_back(); - auto & current_field = message->fields.back(); - current_field.field_number = field_descriptor->number(); - current_field.field_descriptor = field_descriptor; - current_field.column_index = column_index; - } - } - } - - if (message->fields.empty()) - return nullptr; - - // Columns should be sorted by field_number, it's necessary for writing protobufs and useful reading protobufs. - std::sort(message->fields.begin(), message->fields.end(), [](const Field & left, const Field & right) - { - return left.field_number < right.field_number; - }); - - for (size_t i = 0; i != message->fields.size(); ++i) - { - auto & field = message->fields[i]; - if (field.nested_message) - field.nested_message->index_in_parent = i; - } - - return message; - } - } - - template - static std::unique_ptr> matchColumnsImpl( - const std::vector & column_names, - const google::protobuf::Descriptor * message_type, - std::vector * field_descriptors_without_match) - { - details::ColumnNameMatcher name_matcher(column_names); - auto message = details::matchColumnsRecursive(name_matcher, message_type, "", field_descriptors_without_match); - if (!message) - details::throwNoCommonColumns(); - return message; - } - - template - static std::unique_ptr> matchColumns( - const std::vector & column_names, - const google::protobuf::Descriptor * message_type) - { - return matchColumnsImpl(column_names, message_type, nullptr); - } - - template - static std::unique_ptr> matchColumns( - const std::vector & column_names, - const google::protobuf::Descriptor * message_type, - std::vector & field_descriptors_without_match) - { - return matchColumnsImpl(column_names, message_type, &field_descriptors_without_match); - } -} - -} - -#endif diff --git a/src/Formats/ProtobufReader.cpp b/src/Formats/ProtobufReader.cpp index 8f28d279c06..0e05b59badf 100644 --- a/src/Formats/ProtobufReader.cpp +++ b/src/Formats/ProtobufReader.cpp @@ -1,14 +1,7 @@ #include "ProtobufReader.h" #if USE_PROTOBUF -# include -# include -# include -# include -# include -# include -# include -# include +# include namespace DB @@ -16,7 +9,6 @@ namespace DB namespace ErrorCodes { extern const int UNKNOWN_PROTOBUF_FORMAT; - extern const int PROTOBUF_BAD_CAST; } @@ -41,36 +33,21 @@ namespace constexpr Int64 END_OF_FILE = -3; Int64 decodeZigZag(UInt64 n) { return static_cast((n >> 1) ^ (~(n & 1) + 1)); } - } -// SimpleReader is an utility class to deserialize protobufs. -// Knows nothing about protobuf schemas, just provides useful functions to deserialize data. -ProtobufReader::SimpleReader::SimpleReader(ReadBuffer & in_, const bool use_length_delimiters_) +ProtobufReader::ProtobufReader(ReadBuffer & in_) : in(in_) - , cursor(0) - , current_message_level(0) - , current_message_end(0) - , field_end(0) - , last_string_pos(-1) - , use_length_delimiters(use_length_delimiters_) { } -[[noreturn]] void ProtobufReader::SimpleReader::throwUnknownFormat() const -{ - throw Exception(std::string("Protobuf messages are corrupted or don't match the provided schema.") + (use_length_delimiters ? " Please note that Protobuf stream is length-delimited: every message is prefixed by its length in varint." : ""), ErrorCodes::UNKNOWN_PROTOBUF_FORMAT); -} - -bool ProtobufReader::SimpleReader::startMessage() +void ProtobufReader::startMessage(bool with_length_delimiter_) { // Start reading a root message. assert(!current_message_level); - if (unlikely(in.eof())) - return false; - if (use_length_delimiters) + root_message_has_length_delimiter = with_length_delimiter_; + if (root_message_has_length_delimiter) { size_t size_of_message = readVarint(); current_message_end = cursor + size_of_message; @@ -80,11 +57,11 @@ bool ProtobufReader::SimpleReader::startMessage() current_message_end = END_OF_FILE; } ++current_message_level; + field_number = next_field_number = 0; field_end = cursor; - return true; } -void ProtobufReader::SimpleReader::endMessage(bool ignore_errors) +void ProtobufReader::endMessage(bool ignore_errors) { if (!current_message_level) return; @@ -94,6 +71,8 @@ void ProtobufReader::SimpleReader::endMessage(bool ignore_errors) { if (cursor < root_message_end) ignore(root_message_end - cursor); + else if (root_message_end == END_OF_FILE) + ignoreAll(); else if (ignore_errors) moveCursorBackward(cursor - root_message_end); else @@ -104,7 +83,7 @@ void ProtobufReader::SimpleReader::endMessage(bool ignore_errors) parent_message_ends.clear(); } -void ProtobufReader::SimpleReader::startNestedMessage() +void ProtobufReader::startNestedMessage() { assert(current_message_level >= 1); if ((cursor > field_end) && (field_end != END_OF_GROUP)) @@ -115,10 +94,11 @@ void ProtobufReader::SimpleReader::startNestedMessage() parent_message_ends.emplace_back(current_message_end); current_message_end = field_end; ++current_message_level; + field_number = next_field_number = 0; field_end = cursor; } -void ProtobufReader::SimpleReader::endNestedMessage() +void ProtobufReader::endNestedMessage() { assert(current_message_level >= 2); if (cursor != current_message_end) @@ -137,12 +117,20 @@ void ProtobufReader::SimpleReader::endNestedMessage() --current_message_level; current_message_end = parent_message_ends.back(); parent_message_ends.pop_back(); + field_number = next_field_number = 0; field_end = cursor; } -bool ProtobufReader::SimpleReader::readFieldNumber(UInt32 & field_number) +bool ProtobufReader::readFieldNumber(int & field_number_) { assert(current_message_level); + if (next_field_number) + { + field_number_ = field_number = next_field_number; + next_field_number = 0; + return true; + } + if (field_end != cursor) { if (field_end == END_OF_VARINT) @@ -183,7 +171,8 @@ bool ProtobufReader::SimpleReader::readFieldNumber(UInt32 & field_number) if (unlikely(varint & (static_cast(0xFFFFFFFF) << 32))) throwUnknownFormat(); UInt32 key = static_cast(varint); - field_number = (key >> 3); + field_number_ = field_number = (key >> 3); + next_field_number = 0; WireType wire_type = static_cast(key & 0x07); switch (wire_type) { @@ -224,77 +213,91 @@ bool ProtobufReader::SimpleReader::readFieldNumber(UInt32 & field_number) throwUnknownFormat(); } -bool ProtobufReader::SimpleReader::readUInt(UInt64 & value) +UInt64 ProtobufReader::readUInt() { + UInt64 value; if (field_end == END_OF_VARINT) { value = readVarint(); field_end = cursor; - return true; } - - if (unlikely(cursor >= field_end)) - return false; - - value = readVarint(); - return true; + else + { + value = readVarint(); + if (cursor < field_end) + next_field_number = field_number; + else if (unlikely(cursor) > field_end) + throwUnknownFormat(); + } + return value; } -bool ProtobufReader::SimpleReader::readInt(Int64 & value) +Int64 ProtobufReader::readInt() { - UInt64 varint; - if (!readUInt(varint)) - return false; - value = static_cast(varint); - return true; + return static_cast(readUInt()); } -bool ProtobufReader::SimpleReader::readSInt(Int64 & value) +Int64 ProtobufReader::readSInt() { - UInt64 varint; - if (!readUInt(varint)) - return false; - value = decodeZigZag(varint); - return true; + return decodeZigZag(readUInt()); } template -bool ProtobufReader::SimpleReader::readFixed(T & value) +T ProtobufReader::readFixed() { - if (unlikely(cursor >= field_end)) - return false; - + if (unlikely(cursor + static_cast(sizeof(T)) > field_end)) + throwUnknownFormat(); + T value; readBinary(&value, sizeof(T)); - return true; + if (cursor < field_end) + next_field_number = field_number; + return value; } -bool ProtobufReader::SimpleReader::readStringInto(PaddedPODArray & str) +template Int32 ProtobufReader::readFixed(); +template UInt32 ProtobufReader::readFixed(); +template Int64 ProtobufReader::readFixed(); +template UInt64 ProtobufReader::readFixed(); +template Float32 ProtobufReader::readFixed(); +template Float64 ProtobufReader::readFixed(); + +void ProtobufReader::readString(String & str) +{ + if (unlikely(cursor > field_end)) + throwUnknownFormat(); + size_t length = field_end - cursor; + str.resize(length); + readBinary(reinterpret_cast(str.data()), length); +} + +void ProtobufReader::readStringAndAppend(PaddedPODArray & str) { - if (unlikely(cursor == last_string_pos)) - return false; /// We don't want to read the same empty string again. - last_string_pos = cursor; if (unlikely(cursor > field_end)) throwUnknownFormat(); size_t length = field_end - cursor; size_t old_size = str.size(); str.resize(old_size + length); readBinary(reinterpret_cast(str.data() + old_size), length); - return true; } -void ProtobufReader::SimpleReader::readBinary(void* data, size_t size) +void ProtobufReader::readBinary(void* data, size_t size) { in.readStrict(reinterpret_cast(data), size); cursor += size; } -void ProtobufReader::SimpleReader::ignore(UInt64 num_bytes) +void ProtobufReader::ignore(UInt64 num_bytes) { in.ignore(num_bytes); cursor += num_bytes; } -void ProtobufReader::SimpleReader::moveCursorBackward(UInt64 num_bytes) +void ProtobufReader::ignoreAll() +{ + cursor += in.tryIgnore(std::numeric_limits::max()); +} + +void ProtobufReader::moveCursorBackward(UInt64 num_bytes) { if (in.offset() < num_bytes) throwUnknownFormat(); @@ -302,7 +305,7 @@ void ProtobufReader::SimpleReader::moveCursorBackward(UInt64 num_bytes) cursor -= num_bytes; } -UInt64 ProtobufReader::SimpleReader::continueReadingVarint(UInt64 first_byte) +UInt64 ProtobufReader::continueReadingVarint(UInt64 first_byte) { UInt64 result = (first_byte & ~static_cast(0x80)); char c; @@ -342,7 +345,7 @@ UInt64 ProtobufReader::SimpleReader::continueReadingVarint(UInt64 first_byte) throwUnknownFormat(); } -void ProtobufReader::SimpleReader::ignoreVarint() +void ProtobufReader::ignoreVarint() { char c; @@ -379,7 +382,7 @@ void ProtobufReader::SimpleReader::ignoreVarint() throwUnknownFormat(); } -void ProtobufReader::SimpleReader::ignoreGroup() +void ProtobufReader::ignoreGroup() { size_t level = 1; while (true) @@ -424,803 +427,15 @@ void ProtobufReader::SimpleReader::ignoreGroup() } } -// Implementation for a converter from any protobuf field type to any DB data type. -class ProtobufReader::ConverterBaseImpl : public ProtobufReader::IConverter +[[noreturn]] void ProtobufReader::throwUnknownFormat() const { -public: - ConverterBaseImpl(SimpleReader & simple_reader_, const google::protobuf::FieldDescriptor * field_) - : simple_reader(simple_reader_), field(field_) {} - - bool readStringInto(PaddedPODArray &) override - { - cannotConvertType("String"); - } - - bool readInt8(Int8 &) override - { - cannotConvertType("Int8"); - } - - bool readUInt8(UInt8 &) override - { - cannotConvertType("UInt8"); - } - - bool readInt16(Int16 &) override - { - cannotConvertType("Int16"); - } - - bool readUInt16(UInt16 &) override - { - cannotConvertType("UInt16"); - } - - bool readInt32(Int32 &) override - { - cannotConvertType("Int32"); - } - - bool readUInt32(UInt32 &) override - { - cannotConvertType("UInt32"); - } - - bool readInt64(Int64 &) override - { - cannotConvertType("Int64"); - } - - bool readUInt64(UInt64 &) override - { - cannotConvertType("UInt64"); - } - - bool readUInt128(UInt128 &) override - { - cannotConvertType("UInt128"); - } - - bool readInt128(Int128 &) override { cannotConvertType("Int128"); } - bool readInt256(Int256 &) override { cannotConvertType("Int256"); } - bool readUInt256(UInt256 &) override { cannotConvertType("UInt256"); } - - bool readFloat32(Float32 &) override - { - cannotConvertType("Float32"); - } - - bool readFloat64(Float64 &) override - { - cannotConvertType("Float64"); - } - - void prepareEnumMapping8(const std::vector> &) override {} - void prepareEnumMapping16(const std::vector> &) override {} - - bool readEnum8(Int8 &) override - { - cannotConvertType("Enum"); - } - - bool readEnum16(Int16 &) override - { - cannotConvertType("Enum"); - } - - bool readUUID(UUID &) override - { - cannotConvertType("UUID"); - } - - bool readDate(DayNum &) override - { - cannotConvertType("Date"); - } - - bool readDateTime(time_t &) override - { - cannotConvertType("DateTime"); - } - - bool readDateTime64(DateTime64 &, UInt32) override - { - cannotConvertType("DateTime64"); - } - - bool readDecimal32(Decimal32 &, UInt32, UInt32) override - { - cannotConvertType("Decimal32"); - } - - bool readDecimal64(Decimal64 &, UInt32, UInt32) override - { - cannotConvertType("Decimal64"); - } - - bool readDecimal128(Decimal128 &, UInt32, UInt32) override - { - cannotConvertType("Decimal128"); - } - - bool readDecimal256(Decimal256 &, UInt32, UInt32) override - { - cannotConvertType("Decimal256"); - } - - - bool readAggregateFunction(const AggregateFunctionPtr &, AggregateDataPtr, Arena &) override - { - cannotConvertType("AggregateFunction"); - } - -protected: - [[noreturn]] void cannotConvertType(const String & type_name) - { - throw Exception( - String("Could not convert type '") + field->type_name() + "' from protobuf field '" + field->name() + "' to data type '" - + type_name + "'", - ErrorCodes::PROTOBUF_BAD_CAST); - } - - [[noreturn]] void cannotConvertValue(const String & value, const String & type_name) - { - throw Exception( - "Could not convert value '" + value + "' from protobuf field '" + field->name() + "' to data type '" + type_name + "'", - ErrorCodes::PROTOBUF_BAD_CAST); - } - - template - To numericCast(From value) - { - if constexpr (std::is_same_v) - return value; - To result; - try - { - result = boost::numeric_cast(value); - } - catch (boost::numeric::bad_numeric_cast &) - { - cannotConvertValue(toString(value), TypeName::get()); - } - return result; - } - - template - To parseFromString(const PaddedPODArray & str) - { - try - { - To result; - ReadBufferFromString buf(str); - readText(result, buf); - return result; - } - catch (...) - { - cannotConvertValue(StringRef(str.data(), str.size()).toString(), TypeName::get()); - } - } - - SimpleReader & simple_reader; - const google::protobuf::FieldDescriptor * field; -}; - - -class ProtobufReader::ConverterFromString : public ConverterBaseImpl -{ -public: - using ConverterBaseImpl::ConverterBaseImpl; - - bool readStringInto(PaddedPODArray & str) override { return simple_reader.readStringInto(str); } - - bool readInt8(Int8 & value) override { return readNumeric(value); } - bool readUInt8(UInt8 & value) override { return readNumeric(value); } - bool readInt16(Int16 & value) override { return readNumeric(value); } - bool readUInt16(UInt16 & value) override { return readNumeric(value); } - bool readInt32(Int32 & value) override { return readNumeric(value); } - bool readUInt32(UInt32 & value) override { return readNumeric(value); } - bool readInt64(Int64 & value) override { return readNumeric(value); } - bool readUInt64(UInt64 & value) override { return readNumeric(value); } - bool readFloat32(Float32 & value) override { return readNumeric(value); } - bool readFloat64(Float64 & value) override { return readNumeric(value); } - - void prepareEnumMapping8(const std::vector> & name_value_pairs) override - { - prepareEnumNameToValueMap(name_value_pairs); - } - void prepareEnumMapping16(const std::vector> & name_value_pairs) override - { - prepareEnumNameToValueMap(name_value_pairs); - } - - bool readEnum8(Int8 & value) override { return readEnum(value); } - bool readEnum16(Int16 & value) override { return readEnum(value); } - - bool readUUID(UUID & uuid) override - { - if (!readTempString()) - return false; - ReadBufferFromString buf(temp_string); - readUUIDText(uuid, buf); - return true; - } - - bool readDate(DayNum & date) override - { - if (!readTempString()) - return false; - ReadBufferFromString buf(temp_string); - readDateText(date, buf); - return true; - } - - bool readDateTime(time_t & tm) override - { - if (!readTempString()) - return false; - ReadBufferFromString buf(temp_string); - readDateTimeText(tm, buf); - return true; - } - - bool readDateTime64(DateTime64 & date_time, UInt32 scale) override - { - if (!readTempString()) - return false; - ReadBufferFromString buf(temp_string); - readDateTime64Text(date_time, scale, buf); - return true; - } - - bool readDecimal32(Decimal32 & decimal, UInt32 precision, UInt32 scale) override { return readDecimal(decimal, precision, scale); } - bool readDecimal64(Decimal64 & decimal, UInt32 precision, UInt32 scale) override { return readDecimal(decimal, precision, scale); } - bool readDecimal128(Decimal128 & decimal, UInt32 precision, UInt32 scale) override { return readDecimal(decimal, precision, scale); } - bool readDecimal256(Decimal256 & decimal, UInt32 precision, UInt32 scale) override { return readDecimal(decimal, precision, scale); } - - bool readAggregateFunction(const AggregateFunctionPtr & function, AggregateDataPtr place, Arena & arena) override - { - if (!readTempString()) - return false; - ReadBufferFromString buf(temp_string); - function->deserialize(place, buf, &arena); - return true; - } - -private: - bool readTempString() - { - temp_string.clear(); - return simple_reader.readStringInto(temp_string); - } - - template - bool readNumeric(T & value) - { - if (!readTempString()) - return false; - value = parseFromString(temp_string); - return true; - } - - template - bool readEnum(T & value) - { - if (!readTempString()) - return false; - StringRef ref(temp_string.data(), temp_string.size()); - auto it = enum_name_to_value_map->find(ref); - if (it == enum_name_to_value_map->end()) - cannotConvertValue(ref.toString(), "Enum"); - value = static_cast(it->second); - return true; - } - - template - bool readDecimal(Decimal & decimal, UInt32 precision, UInt32 scale) - { - if (!readTempString()) - return false; - ReadBufferFromString buf(temp_string); - DataTypeDecimal>::readText(decimal, buf, precision, scale); - return true; - } - - template - void prepareEnumNameToValueMap(const std::vector> & name_value_pairs) - { - if (likely(enum_name_to_value_map.has_value())) - return; - enum_name_to_value_map.emplace(); - for (const auto & name_value_pair : name_value_pairs) - enum_name_to_value_map->emplace(name_value_pair.first, name_value_pair.second); - } - - PaddedPODArray temp_string; - std::optional> enum_name_to_value_map; -}; - -# define PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(field_type_id) \ - template <> \ - std::unique_ptr ProtobufReader::createConverter( \ - const google::protobuf::FieldDescriptor * field) \ - { \ - return std::make_unique(simple_reader, field); \ - } -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(google::protobuf::FieldDescriptor::TYPE_STRING) -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(google::protobuf::FieldDescriptor::TYPE_BYTES) - -# undef PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS - - -template -class ProtobufReader::ConverterFromNumber : public ConverterBaseImpl -{ -public: - using ConverterBaseImpl::ConverterBaseImpl; - - bool readStringInto(PaddedPODArray & str) override - { - FromType number; - if (!readField(number)) - return false; - WriteBufferFromVector> buf(str); - writeText(number, buf); - return true; - } - - bool readInt8(Int8 & value) override { return readNumeric(value); } - bool readUInt8(UInt8 & value) override { return readNumeric(value); } - bool readInt16(Int16 & value) override { return readNumeric(value); } - bool readUInt16(UInt16 & value) override { return readNumeric(value); } - bool readInt32(Int32 & value) override { return readNumeric(value); } - bool readUInt32(UInt32 & value) override { return readNumeric(value); } - bool readInt64(Int64 & value) override { return readNumeric(value); } - bool readUInt64(UInt64 & value) override { return readNumeric(value); } - bool readFloat32(Float32 & value) override { return readNumeric(value); } - bool readFloat64(Float64 & value) override { return readNumeric(value); } - - bool readEnum8(Int8 & value) override { return readEnum(value); } - bool readEnum16(Int16 & value) override { return readEnum(value); } - - void prepareEnumMapping8(const std::vector> & name_value_pairs) override - { - prepareSetOfEnumValues(name_value_pairs); - } - void prepareEnumMapping16(const std::vector> & name_value_pairs) override - { - prepareSetOfEnumValues(name_value_pairs); - } - - bool readDate(DayNum & date) override - { - UInt16 number; - if (!readNumeric(number)) - return false; - date = DayNum(number); - return true; - } - - bool readDateTime(time_t & tm) override - { - UInt32 number; - if (!readNumeric(number)) - return false; - tm = number; - return true; - } - - bool readDateTime64(DateTime64 & date_time, UInt32 scale) override - { - return readDecimal(date_time, scale); - } - - bool readDecimal32(Decimal32 & decimal, UInt32, UInt32 scale) override { return readDecimal(decimal, scale); } - bool readDecimal64(Decimal64 & decimal, UInt32, UInt32 scale) override { return readDecimal(decimal, scale); } - bool readDecimal128(Decimal128 & decimal, UInt32, UInt32 scale) override { return readDecimal(decimal, scale); } - -private: - template - bool readNumeric(To & value) - { - FromType number; - if (!readField(number)) - return false; - value = numericCast(number); - return true; - } - - template - bool readEnum(EnumType & value) - { - if constexpr (!is_integer_v) - cannotConvertType("Enum"); // It's not correct to convert floating point to enum. - FromType number; - if (!readField(number)) - return false; - value = numericCast(number); - if (set_of_enum_values->find(value) == set_of_enum_values->end()) - cannotConvertValue(toString(value), "Enum"); - return true; - } - - template - void prepareSetOfEnumValues(const std::vector> & name_value_pairs) - { - if (likely(set_of_enum_values.has_value())) - return; - set_of_enum_values.emplace(); - for (const auto & name_value_pair : name_value_pairs) - set_of_enum_values->emplace(name_value_pair.second); - } - - template - bool readDecimal(Decimal & decimal, UInt32 scale) - { - FromType number; - if (!readField(number)) - return false; - decimal.value = convertToDecimal, DataTypeDecimal>>(number, scale); - return true; - } - - bool readField(FromType & value) - { - if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_INT32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_INT64) && std::is_same_v)) - { - return simple_reader.readInt(value); - } - else if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_UINT32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_UINT64) && std::is_same_v)) - { - return simple_reader.readUInt(value); - } - - else if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_SINT32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SINT64) && std::is_same_v)) - { - return simple_reader.readSInt(value); - } - else - { - static_assert(((field_type_id == google::protobuf::FieldDescriptor::TYPE_FIXED32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SFIXED32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_FIXED64) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SFIXED64) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_FLOAT) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_DOUBLE) && std::is_same_v)); - return simple_reader.readFixed(value); - } - } - - std::optional> set_of_enum_values; -}; - -# define PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(field_type_id, field_type) \ - template <> \ - std::unique_ptr ProtobufReader::createConverter( \ - const google::protobuf::FieldDescriptor * field) \ - { \ - return std::make_unique>(simple_reader, field); /* NOLINT */ \ - } - -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_INT32, Int64); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SINT32, Int64); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_UINT32, UInt64); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_INT64, Int64); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SINT64, Int64); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_UINT64, UInt64); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FIXED32, UInt32); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SFIXED32, Int32); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FIXED64, UInt64); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SFIXED64, Int64); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FLOAT, float); -PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_DOUBLE, double); - -# undef PROTOBUF_READER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS - - -class ProtobufReader::ConverterFromBool : public ConverterBaseImpl -{ -public: - using ConverterBaseImpl::ConverterBaseImpl; - - bool readStringInto(PaddedPODArray & str) override - { - bool b; - if (!readField(b)) - return false; - StringRef ref(b ? "true" : "false"); - str.insert(ref.data, ref.data + ref.size); - return true; - } - - bool readInt8(Int8 & value) override { return readNumeric(value); } - bool readUInt8(UInt8 & value) override { return readNumeric(value); } - bool readInt16(Int16 & value) override { return readNumeric(value); } - bool readUInt16(UInt16 & value) override { return readNumeric(value); } - bool readInt32(Int32 & value) override { return readNumeric(value); } - bool readUInt32(UInt32 & value) override { return readNumeric(value); } - bool readInt64(Int64 & value) override { return readNumeric(value); } - bool readUInt64(UInt64 & value) override { return readNumeric(value); } - bool readFloat32(Float32 & value) override { return readNumeric(value); } - bool readFloat64(Float64 & value) override { return readNumeric(value); } - bool readDecimal32(Decimal32 & decimal, UInt32, UInt32) override { return readNumeric(decimal.value); } - bool readDecimal64(Decimal64 & decimal, UInt32, UInt32) override { return readNumeric(decimal.value); } - bool readDecimal128(Decimal128 & decimal, UInt32, UInt32) override { return readNumeric(decimal.value); } - -private: - template - bool readNumeric(T & value) - { - bool b; - if (!readField(b)) - return false; - value = b ? 1 : 0; - return true; - } - - bool readField(bool & b) - { - UInt64 number; - if (!simple_reader.readUInt(number)) - return false; - b = static_cast(number); - return true; - } -}; - -template <> -std::unique_ptr ProtobufReader::createConverter( - const google::protobuf::FieldDescriptor * field) -{ - return std::make_unique(simple_reader, field); + throw Exception( + std::string("Protobuf messages are corrupted or don't match the provided schema.") + + (root_message_has_length_delimiter + ? " Please note that Protobuf stream is length-delimited: every message is prefixed by its length in varint." + : ""), + ErrorCodes::UNKNOWN_PROTOBUF_FORMAT); } - - -class ProtobufReader::ConverterFromEnum : public ConverterBaseImpl -{ -public: - using ConverterBaseImpl::ConverterBaseImpl; - - bool readStringInto(PaddedPODArray & str) override - { - prepareEnumPbNumberToNameMap(); - Int64 pbnumber; - if (!readField(pbnumber)) - return false; - auto it = enum_pbnumber_to_name_map->find(pbnumber); - if (it == enum_pbnumber_to_name_map->end()) - cannotConvertValue(toString(pbnumber), "Enum"); - const auto & ref = it->second; - str.insert(ref.data, ref.data + ref.size); - return true; - } - - bool readInt8(Int8 & value) override { return readNumeric(value); } - bool readUInt8(UInt8 & value) override { return readNumeric(value); } - bool readInt16(Int16 & value) override { return readNumeric(value); } - bool readUInt16(UInt16 & value) override { return readNumeric(value); } - bool readInt32(Int32 & value) override { return readNumeric(value); } - bool readUInt32(UInt32 & value) override { return readNumeric(value); } - bool readInt64(Int64 & value) override { return readNumeric(value); } - bool readUInt64(UInt64 & value) override { return readNumeric(value); } - - void prepareEnumMapping8(const std::vector> & name_value_pairs) override - { - prepareEnumPbNumberToValueMap(name_value_pairs); - } - void prepareEnumMapping16(const std::vector> & name_value_pairs) override - { - prepareEnumPbNumberToValueMap(name_value_pairs); - } - - bool readEnum8(Int8 & value) override { return readEnum(value); } - bool readEnum16(Int16 & value) override { return readEnum(value); } - -private: - template - bool readNumeric(T & value) - { - Int64 pbnumber; - if (!readField(pbnumber)) - return false; - value = numericCast(pbnumber); - return true; - } - - template - bool readEnum(T & value) - { - Int64 pbnumber; - if (!readField(pbnumber)) - return false; - if (enum_pbnumber_always_equals_value) - value = static_cast(pbnumber); - else - { - auto it = enum_pbnumber_to_value_map->find(pbnumber); - if (it == enum_pbnumber_to_value_map->end()) - cannotConvertValue(toString(pbnumber), "Enum"); - value = static_cast(it->second); - } - return true; - } - - void prepareEnumPbNumberToNameMap() - { - if (likely(enum_pbnumber_to_name_map.has_value())) - return; - enum_pbnumber_to_name_map.emplace(); - const auto * enum_type = field->enum_type(); - for (int i = 0; i != enum_type->value_count(); ++i) - { - const auto * enum_value = enum_type->value(i); - enum_pbnumber_to_name_map->emplace(enum_value->number(), enum_value->name()); - } - } - - template - void prepareEnumPbNumberToValueMap(const std::vector> & name_value_pairs) - { - if (likely(enum_pbnumber_to_value_map.has_value())) - return; - enum_pbnumber_to_value_map.emplace(); - enum_pbnumber_always_equals_value = true; - for (const auto & name_value_pair : name_value_pairs) - { - Int16 value = name_value_pair.second; // NOLINT - const auto * enum_descriptor = field->enum_type()->FindValueByName(name_value_pair.first); - if (enum_descriptor) - { - enum_pbnumber_to_value_map->emplace(enum_descriptor->number(), value); - if (enum_descriptor->number() != value) - enum_pbnumber_always_equals_value = false; - } - else - enum_pbnumber_always_equals_value = false; - } - } - - bool readField(Int64 & enum_pbnumber) - { - return simple_reader.readInt(enum_pbnumber); - } - - std::optional> enum_pbnumber_to_name_map; - std::optional> enum_pbnumber_to_value_map; - bool enum_pbnumber_always_equals_value; -}; - -template <> -std::unique_ptr ProtobufReader::createConverter( - const google::protobuf::FieldDescriptor * field) -{ - return std::make_unique(simple_reader, field); -} - - -ProtobufReader::ProtobufReader( - ReadBuffer & in_, const google::protobuf::Descriptor * message_type, const std::vector & column_names, const bool use_length_delimiters_) - : simple_reader(in_, use_length_delimiters_) -{ - root_message = ProtobufColumnMatcher::matchColumns(column_names, message_type); - setTraitsDataAfterMatchingColumns(root_message.get()); -} - -ProtobufReader::~ProtobufReader() = default; - -void ProtobufReader::setTraitsDataAfterMatchingColumns(Message * message) -{ - for (Field & field : message->fields) - { - if (field.nested_message) - { - setTraitsDataAfterMatchingColumns(field.nested_message.get()); - continue; - } - switch (field.field_descriptor->type()) - { -# define PROTOBUF_READER_CONVERTER_CREATING_CASE(field_type_id) \ - case field_type_id: \ - field.data.converter = createConverter(field.field_descriptor); \ - break - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_STRING); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_BYTES); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_INT32); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SINT32); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_UINT32); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FIXED32); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SFIXED32); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_INT64); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SINT64); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_UINT64); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FIXED64); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SFIXED64); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FLOAT); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_DOUBLE); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_BOOL); - PROTOBUF_READER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_ENUM); -# undef PROTOBUF_READER_CONVERTER_CREATING_CASE - default: - __builtin_unreachable(); - } - message->data.field_number_to_field_map.emplace(field.field_number, &field); - } -} - -bool ProtobufReader::startMessage() -{ - if (!simple_reader.startMessage()) - return false; - current_message = root_message.get(); - current_field_index = 0; - return true; -} - -void ProtobufReader::endMessage(bool try_ignore_errors) -{ - simple_reader.endMessage(try_ignore_errors); - current_message = nullptr; - current_converter = nullptr; -} - -bool ProtobufReader::readColumnIndex(size_t & column_index) -{ - while (true) - { - UInt32 field_number; - if (!simple_reader.readFieldNumber(field_number)) - { - if (!current_message->parent) - { - current_converter = nullptr; - return false; - } - simple_reader.endNestedMessage(); - current_field_index = current_message->index_in_parent; - current_message = current_message->parent; - continue; - } - - const Field * field = nullptr; - for (; current_field_index < current_message->fields.size(); ++current_field_index) - { - const Field & f = current_message->fields[current_field_index]; - if (f.field_number == field_number) - { - field = &f; - break; - } - if (f.field_number > field_number) - break; - } - - if (!field) - { - const auto & field_number_to_field_map = current_message->data.field_number_to_field_map; - auto it = field_number_to_field_map.find(field_number); - if (it == field_number_to_field_map.end()) - continue; - field = it->second; - } - - if (field->nested_message) - { - simple_reader.startNestedMessage(); - current_message = field->nested_message.get(); - current_field_index = 0; - continue; - } - - column_index = field->column_index; - current_converter = field->data.converter.get(); - return true; - } -} - } #endif diff --git a/src/Formats/ProtobufReader.h b/src/Formats/ProtobufReader.h index b2a0714a57a..31d6f9a08e0 100644 --- a/src/Formats/ProtobufReader.h +++ b/src/Formats/ProtobufReader.h @@ -1,258 +1,72 @@ #pragma once -#include -#include -#include -#include - #if !defined(ARCADIA_BUILD) -# include "config_formats.h" +# include "config_formats.h" #endif #if USE_PROTOBUF -# include -# include -# include -# include "ProtobufColumnMatcher.h" +# include +# include -namespace google -{ -namespace protobuf -{ - class Descriptor; -} -} namespace DB { -class Arena; -class IAggregateFunction; class ReadBuffer; -using AggregateDataPtr = char *; -using AggregateFunctionPtr = std::shared_ptr; - - -/** Deserializes a protobuf, tries to cast data types if necessarily. - */ -class ProtobufReader : private boost::noncopyable -{ -public: - ProtobufReader(ReadBuffer & in_, const google::protobuf::Descriptor * message_type, const std::vector & column_names, const bool use_length_delimiters_); - ~ProtobufReader(); - - /// Should be called when we start reading a new message. - bool startMessage(); - - /// Ends reading a message. - void endMessage(bool ignore_errors = false); - - /// Reads the column index. - /// The function returns false if there are no more columns to read (call endMessage() in this case). - bool readColumnIndex(size_t & column_index); - - /// Reads a value which should be put to column at index received with readColumnIndex(). - /// The function returns false if there are no more values to read now (call readColumnIndex() in this case). - bool readNumber(Int8 & value) { return current_converter->readInt8(value); } - bool readNumber(UInt8 & value) { return current_converter->readUInt8(value); } - bool readNumber(Int16 & value) { return current_converter->readInt16(value); } - bool readNumber(UInt16 & value) { return current_converter->readUInt16(value); } - bool readNumber(Int32 & value) { return current_converter->readInt32(value); } - bool readNumber(UInt32 & value) { return current_converter->readUInt32(value); } - bool readNumber(Int64 & value) { return current_converter->readInt64(value); } - bool readNumber(UInt64 & value) { return current_converter->readUInt64(value); } - bool readNumber(Int128 & value) { return current_converter->readInt128(value); } - bool readNumber(UInt128 & value) { return current_converter->readUInt128(value); } - bool readNumber(Int256 & value) { return current_converter->readInt256(value); } - bool readNumber(UInt256 & value) { return current_converter->readUInt256(value); } - bool readNumber(Float32 & value) { return current_converter->readFloat32(value); } - bool readNumber(Float64 & value) { return current_converter->readFloat64(value); } - - bool readStringInto(PaddedPODArray & str) { return current_converter->readStringInto(str); } - - void prepareEnumMapping(const std::vector> & name_value_pairs) { current_converter->prepareEnumMapping8(name_value_pairs); } - void prepareEnumMapping(const std::vector> & name_value_pairs) { current_converter->prepareEnumMapping16(name_value_pairs); } - bool readEnum(Int8 & value) { return current_converter->readEnum8(value); } - bool readEnum(Int16 & value) { return current_converter->readEnum16(value); } - - bool readUUID(UUID & uuid) { return current_converter->readUUID(uuid); } - bool readDate(DayNum & date) { return current_converter->readDate(date); } - bool readDateTime(time_t & tm) { return current_converter->readDateTime(tm); } - bool readDateTime64(DateTime64 & tm, UInt32 scale) { return current_converter->readDateTime64(tm, scale); } - - bool readDecimal(Decimal32 & decimal, UInt32 precision, UInt32 scale) { return current_converter->readDecimal32(decimal, precision, scale); } - bool readDecimal(Decimal64 & decimal, UInt32 precision, UInt32 scale) { return current_converter->readDecimal64(decimal, precision, scale); } - bool readDecimal(Decimal128 & decimal, UInt32 precision, UInt32 scale) { return current_converter->readDecimal128(decimal, precision, scale); } - bool readDecimal(Decimal256 & decimal, UInt32 precision, UInt32 scale) { return current_converter->readDecimal256(decimal, precision, scale); } - - bool readAggregateFunction(const AggregateFunctionPtr & function, AggregateDataPtr place, Arena & arena) { return current_converter->readAggregateFunction(function, place, arena); } - - /// Call it after calling one of the read*() function to determine if there are more values available for reading. - bool ALWAYS_INLINE canReadMoreValues() const { return simple_reader.canReadMoreValues(); } - -private: - class SimpleReader - { - public: - SimpleReader(ReadBuffer & in_, const bool use_length_delimiters_); - bool startMessage(); - void endMessage(bool ignore_errors); - void startNestedMessage(); - void endNestedMessage(); - bool readFieldNumber(UInt32 & field_number); - bool readInt(Int64 & value); - bool readSInt(Int64 & value); - bool readUInt(UInt64 & value); - template bool readFixed(T & value); - bool readStringInto(PaddedPODArray & str); - - bool ALWAYS_INLINE canReadMoreValues() const { return cursor < field_end; } - - private: - void readBinary(void * data, size_t size); - void ignore(UInt64 num_bytes); - void moveCursorBackward(UInt64 num_bytes); - - UInt64 ALWAYS_INLINE readVarint() - { - char c; - in.readStrict(c); - UInt64 first_byte = static_cast(c); - ++cursor; - if (likely(!(c & 0x80))) - return first_byte; - return continueReadingVarint(first_byte); - } - - UInt64 continueReadingVarint(UInt64 first_byte); - void ignoreVarint(); - void ignoreGroup(); - [[noreturn]] void throwUnknownFormat() const; - - ReadBuffer & in; - Int64 cursor; - size_t current_message_level; - Int64 current_message_end; - std::vector parent_message_ends; - Int64 field_end; - Int64 last_string_pos; - const bool use_length_delimiters; - }; - - class IConverter - { - public: - virtual ~IConverter() = default; - virtual bool readStringInto(PaddedPODArray &) = 0; - virtual bool readInt8(Int8&) = 0; - virtual bool readUInt8(UInt8 &) = 0; - virtual bool readInt16(Int16 &) = 0; - virtual bool readUInt16(UInt16 &) = 0; - virtual bool readInt32(Int32 &) = 0; - virtual bool readUInt32(UInt32 &) = 0; - virtual bool readInt64(Int64 &) = 0; - virtual bool readUInt64(UInt64 &) = 0; - virtual bool readInt128(Int128 &) = 0; - virtual bool readUInt128(UInt128 &) = 0; - - virtual bool readInt256(Int256 &) = 0; - virtual bool readUInt256(UInt256 &) = 0; - - virtual bool readFloat32(Float32 &) = 0; - virtual bool readFloat64(Float64 &) = 0; - virtual void prepareEnumMapping8(const std::vector> &) = 0; - virtual void prepareEnumMapping16(const std::vector> &) = 0; - virtual bool readEnum8(Int8 &) = 0; - virtual bool readEnum16(Int16 &) = 0; - virtual bool readUUID(UUID &) = 0; - virtual bool readDate(DayNum &) = 0; - virtual bool readDateTime(time_t &) = 0; - virtual bool readDateTime64(DateTime64 &, UInt32) = 0; - virtual bool readDecimal32(Decimal32 &, UInt32, UInt32) = 0; - virtual bool readDecimal64(Decimal64 &, UInt32, UInt32) = 0; - virtual bool readDecimal128(Decimal128 &, UInt32, UInt32) = 0; - virtual bool readDecimal256(Decimal256 &, UInt32, UInt32) = 0; - virtual bool readAggregateFunction(const AggregateFunctionPtr &, AggregateDataPtr, Arena &) = 0; - }; - - class ConverterBaseImpl; - class ConverterFromString; - template class ConverterFromNumber; - class ConverterFromBool; - class ConverterFromEnum; - - struct ColumnMatcherTraits - { - struct FieldData - { - std::unique_ptr converter; - }; - struct MessageData - { - std::unordered_map*> field_number_to_field_map; - }; - }; - using Message = ProtobufColumnMatcher::Message; - using Field = ProtobufColumnMatcher::Field; - - void setTraitsDataAfterMatchingColumns(Message * message); - - template - std::unique_ptr createConverter(const google::protobuf::FieldDescriptor * field); - - SimpleReader simple_reader; - std::unique_ptr root_message; - Message* current_message = nullptr; - size_t current_field_index = 0; - IConverter* current_converter = nullptr; -}; - -} - -#else - -namespace DB -{ -class Arena; -class IAggregateFunction; -class ReadBuffer; -using AggregateDataPtr = char *; -using AggregateFunctionPtr = std::shared_ptr; +/// Utility class for reading in the Protobuf format. +/// Knows nothing about protobuf schemas, just provides useful functions to serialize data. class ProtobufReader { public: - bool startMessage() { return false; } - void endMessage() {} - bool readColumnIndex(size_t &) { return false; } - bool readNumber(Int8 &) { return false; } - bool readNumber(UInt8 &) { return false; } - bool readNumber(Int16 &) { return false; } - bool readNumber(UInt16 &) { return false; } - bool readNumber(Int32 &) { return false; } - bool readNumber(UInt32 &) { return false; } - bool readNumber(Int64 &) { return false; } - bool readNumber(UInt64 &) { return false; } - bool readNumber(Int128 &) { return false; } - bool readNumber(UInt128 &) { return false; } - bool readNumber(Int256 &) { return false; } - bool readNumber(UInt256 &) { return false; } - bool readNumber(Float32 &) { return false; } - bool readNumber(Float64 &) { return false; } - bool readStringInto(PaddedPODArray &) { return false; } - void prepareEnumMapping(const std::vector> &) {} - void prepareEnumMapping(const std::vector> &) {} - bool readEnum(Int8 &) { return false; } - bool readEnum(Int16 &) { return false; } - bool readUUID(UUID &) { return false; } - bool readDate(DayNum &) { return false; } - bool readDateTime(time_t &) { return false; } - bool readDateTime64(DateTime64 & /*tm*/, UInt32 /*scale*/) { return false; } - bool readDecimal(Decimal32 &, UInt32, UInt32) { return false; } - bool readDecimal(Decimal64 &, UInt32, UInt32) { return false; } - bool readDecimal(Decimal128 &, UInt32, UInt32) { return false; } - bool readDecimal(Decimal256 &, UInt32, UInt32) { return false; } - bool readAggregateFunction(const AggregateFunctionPtr &, AggregateDataPtr, Arena &) { return false; } - bool canReadMoreValues() const { return false; } + ProtobufReader(ReadBuffer & in_); + + void startMessage(bool with_length_delimiter_); + void endMessage(bool ignore_errors); + void startNestedMessage(); + void endNestedMessage(); + + bool readFieldNumber(int & field_number); + Int64 readInt(); + Int64 readSInt(); + UInt64 readUInt(); + template T readFixed(); + + void readString(String & str); + void readStringAndAppend(PaddedPODArray & str); + + bool eof() const { return in.eof(); } + +private: + void readBinary(void * data, size_t size); + void ignore(UInt64 num_bytes); + void ignoreAll(); + void moveCursorBackward(UInt64 num_bytes); + + UInt64 ALWAYS_INLINE readVarint() + { + char c; + in.readStrict(c); + UInt64 first_byte = static_cast(c); + ++cursor; + if (likely(!(c & 0x80))) + return first_byte; + return continueReadingVarint(first_byte); + } + + UInt64 continueReadingVarint(UInt64 first_byte); + void ignoreVarint(); + void ignoreGroup(); + [[noreturn]] void throwUnknownFormat() const; + + ReadBuffer & in; + Int64 cursor = 0; + bool root_message_has_length_delimiter = false; + size_t current_message_level = 0; + Int64 current_message_end = 0; + std::vector parent_message_ends; + int field_number = 0; + int next_field_number = 0; + Int64 field_end = 0; }; } diff --git a/src/Formats/ProtobufSerializer.cpp b/src/Formats/ProtobufSerializer.cpp new file mode 100644 index 00000000000..82149460773 --- /dev/null +++ b/src/Formats/ProtobufSerializer.cpp @@ -0,0 +1,2921 @@ +#include + +#if USE_PROTOBUF +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NO_COLUMNS_SERIALIZED_TO_PROTOBUF_FIELDS; + extern const int MULTIPLE_COLUMNS_SERIALIZED_TO_SAME_PROTOBUF_FIELD; + extern const int NO_COLUMN_SERIALIZED_TO_REQUIRED_PROTOBUF_FIELD; + extern const int DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD; + extern const int PROTOBUF_FIELD_NOT_REPEATED; + extern const int PROTOBUF_BAD_CAST; + extern const int LOGICAL_ERROR; +} + +namespace +{ + using FieldDescriptor = google::protobuf::FieldDescriptor; + using MessageDescriptor = google::protobuf::Descriptor; + using FieldTypeId = google::protobuf::FieldDescriptor::Type; + + + /// Compares column's name with protobuf field's name. + /// This comparison is case-insensitive and ignores the difference between '.' and '_' + struct ColumnNameWithProtobufFieldNameComparator + { + static bool equals(char c1, char c2) + { + return convertChar(c1) == convertChar(c2); + } + + static bool equals(const std::string_view & s1, const std::string_view & s2) + { + return (s1.length() == s2.length()) + && std::equal(s1.begin(), s1.end(), s2.begin(), [](char c1, char c2) { return convertChar(c1) == convertChar(c2); }); + } + + static bool less(const std::string_view & s1, const std::string_view & s2) + { + return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), s2.end(), [](char c1, char c2) { return convertChar(c1) < convertChar(c2); }); + } + + static bool startsWith(const std::string_view & s1, const std::string_view & s2) + { + return (s1.length() >= s2.length()) && equals(s1.substr(0, s2.length()), s2); + } + + static char convertChar(char c) + { + c = tolower(c); + if (c == '.') + c = '_'; + return c; + } + }; + + + // Should we omit null values (zero for numbers / empty string for strings) while storing them. + bool shouldSkipZeroOrEmpty(const FieldDescriptor & field_descriptor) + { + if (!field_descriptor.is_optional()) + return false; + if (field_descriptor.containing_type()->options().map_entry()) + return false; + return field_descriptor.message_type() || (field_descriptor.file()->syntax() == google::protobuf::FileDescriptor::SYNTAX_PROTO3); + } + + // Should we pack repeated values while storing them. + bool shouldPackRepeated(const FieldDescriptor & field_descriptor) + { + if (!field_descriptor.is_repeated()) + return false; + switch (field_descriptor.type()) + { + case FieldTypeId::TYPE_INT32: + case FieldTypeId::TYPE_UINT32: + case FieldTypeId::TYPE_SINT32: + case FieldTypeId::TYPE_INT64: + case FieldTypeId::TYPE_UINT64: + case FieldTypeId::TYPE_SINT64: + case FieldTypeId::TYPE_FIXED32: + case FieldTypeId::TYPE_SFIXED32: + case FieldTypeId::TYPE_FIXED64: + case FieldTypeId::TYPE_SFIXED64: + case FieldTypeId::TYPE_FLOAT: + case FieldTypeId::TYPE_DOUBLE: + case FieldTypeId::TYPE_BOOL: + case FieldTypeId::TYPE_ENUM: + break; + default: + return false; + } + if (field_descriptor.options().has_packed()) + return field_descriptor.options().packed(); + return field_descriptor.file()->syntax() == google::protobuf::FileDescriptor::SYNTAX_PROTO3; + } + + + struct ProtobufReaderOrWriter + { + ProtobufReaderOrWriter(ProtobufReader & reader_) : reader(&reader_) {} // NOLINT(google-explicit-constructor) + ProtobufReaderOrWriter(ProtobufWriter & writer_) : writer(&writer_) {} // NOLINT(google-explicit-constructor) + ProtobufReader * const reader = nullptr; + ProtobufWriter * const writer = nullptr; + }; + + + /// Base class for all serializers which serialize a single value. + class ProtobufSerializerSingleValue : public ProtobufSerializer + { + protected: + ProtobufSerializerSingleValue(const FieldDescriptor & field_descriptor_, const ProtobufReaderOrWriter & reader_or_writer_) + : field_descriptor(field_descriptor_) + , field_typeid(field_descriptor_.type()) + , field_tag(field_descriptor.number()) + , reader(reader_or_writer_.reader) + , writer(reader_or_writer_.writer) + , skip_zero_or_empty(shouldSkipZeroOrEmpty(field_descriptor)) + { + } + + void setColumns(const ColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + column = columns[0]; + } + + void setColumns(const MutableColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + column = columns[0]->getPtr(); + } + + template + void writeInt(NumberType value) + { + auto casted = castNumber(value); + if (casted || !skip_zero_or_empty) + writer->writeInt(field_tag, casted); + } + + template + void writeSInt(NumberType value) + { + auto casted = castNumber(value); + if (casted || !skip_zero_or_empty) + writer->writeSInt(field_tag, casted); + } + + template + void writeUInt(NumberType value) + { + auto casted = castNumber(value); + if (casted || !skip_zero_or_empty) + writer->writeUInt(field_tag, casted); + } + + template + void writeFixed(NumberType value) + { + auto casted = castNumber(value); + if (casted || !skip_zero_or_empty) + writer->writeFixed(field_tag, casted); + } + + Int64 readInt() { return reader->readInt(); } + Int64 readSInt() { return reader->readSInt(); } + UInt64 readUInt() { return reader->readUInt(); } + + template + FieldType readFixed() + { + return reader->readFixed(); + } + + void writeStr(const std::string_view & str) + { + if (!str.empty() || !skip_zero_or_empty) + writer->writeString(field_tag, str); + } + + void readStr(String & str) { reader->readString(str); } + void readStrAndAppend(PaddedPODArray & str) { reader->readStringAndAppend(str); } + + template + DestType parseFromStr(const std::string_view & str) const + { + try + { + DestType result; + ReadBufferFromMemory buf(str.data(), str.length()); + readText(result, buf); + return result; + } + catch (...) + { + cannotConvertValue(str, "String", TypeName::get()); + } + } + + template + DestType castNumber(SrcType value) const + { + if constexpr (std::is_same_v) + return value; + DestType result; + try + { + /// TODO: use accurate::convertNumeric() maybe? + result = boost::numeric_cast(value); + } + catch (boost::numeric::bad_numeric_cast &) + { + cannotConvertValue(toString(value), TypeName::get(), TypeName::get()); + } + return result; + } + + [[noreturn]] void cannotConvertValue(const std::string_view & src_value, const std::string_view & src_type_name, const std::string_view & dest_type_name) const + { + throw Exception( + "Could not convert value '" + String{src_value} + "' from type " + String{src_type_name} + " to type " + String{dest_type_name} + + " while " + (reader ? "reading" : "writing") + " field " + field_descriptor.name(), + ErrorCodes::PROTOBUF_BAD_CAST); + } + + const FieldDescriptor & field_descriptor; + const FieldTypeId field_typeid; + const int field_tag; + ProtobufReader * const reader; + ProtobufWriter * const writer; + ColumnPtr column; + + private: + const bool skip_zero_or_empty; + }; + + + /// Serializes any ColumnVector to a field of any type except TYPE_MESSAGE, TYPE_GROUP. + /// NumberType must be one of the following types: Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, + /// Int128, UInt128, Int256, UInt256, Float32, Float64. + /// And the field's type cannot be TYPE_ENUM if NumberType is Float32 or Float64. + template + class ProtobufSerializerNumber : public ProtobufSerializerSingleValue + { + public: + using ColumnType = ColumnVector; + + ProtobufSerializerNumber(const FieldDescriptor & field_descriptor_, const ProtobufReaderOrWriter & reader_or_writer_) + : ProtobufSerializerSingleValue(field_descriptor_, reader_or_writer_) + { + setFunctions(); + } + + void writeRow(size_t row_num) override + { + const auto & column_vector = assert_cast(*column); + write_function(column_vector.getElement(row_num)); + } + + void readRow(size_t row_num) override + { + NumberType value = read_function(); + auto & column_vector = assert_cast(column->assumeMutableRef()); + if (row_num < column_vector.size()) + column_vector.getElement(row_num) = value; + else + column_vector.insertValue(value); + } + + void insertDefaults(size_t row_num) override + { + auto & column_vector = assert_cast(column->assumeMutableRef()); + if (row_num < column_vector.size()) + return; + column_vector.insertValue(getDefaultNumber()); + } + + private: + void setFunctions() + { + switch (field_typeid) + { + case FieldTypeId::TYPE_INT32: + { + write_function = [this](NumberType value) { writeInt(value); }; + read_function = [this]() -> NumberType { return castNumber(readInt()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_int32()); }; + break; + } + + case FieldTypeId::TYPE_SINT32: + { + write_function = [this](NumberType value) { writeSInt(value); }; + read_function = [this]() -> NumberType { return castNumber(readSInt()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_int32()); }; + break; + } + + case FieldTypeId::TYPE_UINT32: + { + write_function = [this](NumberType value) { writeUInt(value); }; + read_function = [this]() -> NumberType { return castNumber(readUInt()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_uint32()); }; + break; + } + + case FieldTypeId::TYPE_INT64: + { + write_function = [this](NumberType value) { writeInt(value); }; + read_function = [this]() -> NumberType { return castNumber(readInt()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_int64()); }; + break; + } + + case FieldTypeId::TYPE_SINT64: + { + write_function = [this](NumberType value) { writeSInt(value); }; + read_function = [this]() -> NumberType { return castNumber(readSInt()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_int64()); }; + break; + } + + case FieldTypeId::TYPE_UINT64: + { + write_function = [this](NumberType value) { writeUInt(value); }; + read_function = [this]() -> NumberType { return castNumber(readUInt()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_uint64()); }; + break; + } + + case FieldTypeId::TYPE_FIXED32: + { + write_function = [this](NumberType value) { writeFixed(value); }; + read_function = [this]() -> NumberType { return castNumber(readFixed()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_uint32()); }; + break; + } + + case FieldTypeId::TYPE_SFIXED32: + { + write_function = [this](NumberType value) { writeFixed(value); }; + read_function = [this]() -> NumberType { return castNumber(readFixed()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_int32()); }; + break; + } + + case FieldTypeId::TYPE_FIXED64: + { + write_function = [this](NumberType value) { writeFixed(value); }; + read_function = [this]() -> NumberType { return castNumber(readFixed()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_uint64()); }; + break; + } + + case FieldTypeId::TYPE_SFIXED64: + { + write_function = [this](NumberType value) { writeFixed(value); }; + read_function = [this]() -> NumberType { return castNumber(readFixed()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_int64()); }; + break; + } + + case FieldTypeId::TYPE_FLOAT: + { + write_function = [this](NumberType value) { writeFixed(value); }; + read_function = [this]() -> NumberType { return castNumber(readFixed()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_float()); }; + break; + } + + case FieldTypeId::TYPE_DOUBLE: + { + write_function = [this](NumberType value) { writeFixed(value); }; + read_function = [this]() -> NumberType { return castNumber(readFixed()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_double()); }; + break; + } + + case FieldTypeId::TYPE_BOOL: + { + write_function = [this](NumberType value) + { + if (value == 0) + writeUInt(0); + else if (value == 1) + writeUInt(1); + else + cannotConvertValue(toString(value), TypeName::get(), field_descriptor.type_name()); + }; + + read_function = [this]() -> NumberType + { + UInt64 u64 = readUInt(); + if (u64 < 2) + return static_cast(u64); + else + cannotConvertValue(toString(u64), field_descriptor.type_name(), TypeName::get()); + }; + + default_function = [this]() -> NumberType { return static_cast(field_descriptor.default_value_bool()); }; + break; + } + + case FieldTypeId::TYPE_STRING: + case FieldTypeId::TYPE_BYTES: + { + write_function = [this](NumberType value) + { + WriteBufferFromString buf{text_buffer}; + writeText(value, buf); + buf.finalize(); + writeStr(text_buffer); + }; + + read_function = [this]() -> NumberType + { + readStr(text_buffer); + return parseFromStr(text_buffer); + }; + + default_function = [this]() -> NumberType { return parseFromStr(field_descriptor.default_value_string()); }; + break; + } + + case FieldTypeId::TYPE_ENUM: + { + if (std::is_floating_point_v) + failedToSetFunctions(); + + write_function = [this](NumberType value) + { + int number = castNumber(value); + checkProtobufEnumValue(number); + writeInt(number); + }; + + read_function = [this]() -> NumberType { return castNumber(readInt()); }; + default_function = [this]() -> NumberType { return castNumber(field_descriptor.default_value_enum()->number()); }; + break; + } + + default: + failedToSetFunctions(); + } + } + + [[noreturn]] void failedToSetFunctions() const + { + throw Exception( + "The field " + quoteString(field_descriptor.full_name()) + " has an incompatible type " + field_descriptor.type_name() + + " for serialization of the data type " + quoteString(TypeName::get()), + ErrorCodes::DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD); + } + + NumberType getDefaultNumber() + { + if (!default_number) + default_number = default_function(); + return *default_number; + } + + void checkProtobufEnumValue(int value) const + { + const auto * enum_value_descriptor = field_descriptor.enum_type()->FindValueByNumber(value); + if (!enum_value_descriptor) + cannotConvertValue(toString(value), TypeName::get(), field_descriptor.type_name()); + } + + protected: + std::function write_function; + std::function read_function; + std::function default_function; + String text_buffer; + + private: + std::optional default_number; + }; + + + /// Serializes ColumnString or ColumnFixedString to a field of any type except TYPE_MESSAGE, TYPE_GROUP. + template + class ProtobufSerializerString : public ProtobufSerializerSingleValue + { + public: + using ColumnType = std::conditional_t; + using StringDataType = std::conditional_t; + + ProtobufSerializerString( + const StringDataType & string_data_type_, + const google::protobuf::FieldDescriptor & field_descriptor_, + const ProtobufReaderOrWriter & reader_or_writer_) + : ProtobufSerializerSingleValue(field_descriptor_, reader_or_writer_) + { + static_assert(is_fixed_string, "This constructor for FixedString only"); + n = string_data_type_.getN(); + setFunctions(); + prepareEnumMapping(); + } + + ProtobufSerializerString( + const google::protobuf::FieldDescriptor & field_descriptor_, const ProtobufReaderOrWriter & reader_or_writer_) + : ProtobufSerializerSingleValue(field_descriptor_, reader_or_writer_) + { + static_assert(!is_fixed_string, "This constructor for String only"); + setFunctions(); + prepareEnumMapping(); + } + + void writeRow(size_t row_num) override + { + const auto & column_string = assert_cast(*column); + write_function(std::string_view{column_string.getDataAt(row_num)}); + } + + void readRow(size_t row_num) override + { + auto & column_string = assert_cast(column->assumeMutableRef()); + const size_t old_size = column_string.size(); + typename ColumnType::Chars & data = column_string.getChars(); + const size_t old_data_size = data.size(); + + if (row_num < old_size) + { + text_buffer.clear(); + read_function(text_buffer); + } + else + { + try + { + read_function(data); + } + catch (...) + { + data.resize_assume_reserved(old_data_size); + throw; + } + } + + if constexpr (is_fixed_string) + { + if (row_num < old_size) + { + ColumnFixedString::alignStringLength(text_buffer, n, 0); + memcpy(data.data() + row_num * n, text_buffer.data(), n); + } + else + ColumnFixedString::alignStringLength(data, n, old_data_size); + } + else + { + if (row_num < old_size) + { + if (row_num != old_size - 1) + throw Exception("Cannot replace a string in the middle of ColumnString", ErrorCodes::LOGICAL_ERROR); + column_string.popBack(1); + } + try + { + data.push_back(0 /* terminating zero */); + column_string.getOffsets().push_back(data.size()); + } + catch (...) + { + data.resize_assume_reserved(old_data_size); + column_string.getOffsets().resize_assume_reserved(old_size); + throw; + } + } + } + + void insertDefaults(size_t row_num) override + { + auto & column_string = assert_cast(column->assumeMutableRef()); + const size_t old_size = column_string.size(); + if (row_num < old_size) + return; + + const auto & default_str = getDefaultString(); + typename ColumnType::Chars & data = column_string.getChars(); + const size_t old_data_size = data.size(); + try + { + data.insert(default_str.data(), default_str.data() + default_str.size()); + } + catch (...) + { + data.resize_assume_reserved(old_data_size); + throw; + } + + if constexpr (!is_fixed_string) + { + try + { + data.push_back(0 /* terminating zero */); + column_string.getOffsets().push_back(data.size()); + } + catch (...) + { + data.resize_assume_reserved(old_data_size); + column_string.getOffsets().resize_assume_reserved(old_size); + throw; + } + } + } + + private: + void setFunctions() + { + switch (field_typeid) + { + case FieldTypeId::TYPE_INT32: + { + write_function = [this](const std::string_view & str) { writeInt(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readInt(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_int32()); }; + break; + } + + case FieldTypeId::TYPE_SINT32: + { + write_function = [this](const std::string_view & str) { writeSInt(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readSInt(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_int32()); }; + break; + } + + case FieldTypeId::TYPE_UINT32: + { + write_function = [this](const std::string_view & str) { writeUInt(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readUInt(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_uint32()); }; + break; + } + + case FieldTypeId::TYPE_INT64: + { + write_function = [this](const std::string_view & str) { writeInt(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readInt(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_int64()); }; + break; + } + + case FieldTypeId::TYPE_SINT64: + { + write_function = [this](const std::string_view & str) { writeSInt(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readSInt(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_int64()); }; + break; + } + + case FieldTypeId::TYPE_UINT64: + { + write_function = [this](const std::string_view & str) { writeUInt(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readUInt(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_uint64()); }; + break; + } + + case FieldTypeId::TYPE_FIXED32: + { + write_function = [this](const std::string_view & str) { writeFixed(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readFixed(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_uint32()); }; + break; + } + + case FieldTypeId::TYPE_SFIXED32: + { + write_function = [this](const std::string_view & str) { writeFixed(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readFixed(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_int32()); }; + break; + } + + case FieldTypeId::TYPE_FIXED64: + { + write_function = [this](const std::string_view & str) { writeFixed(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readFixed(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_uint64()); }; + break; + } + + case FieldTypeId::TYPE_SFIXED64: + { + write_function = [this](const std::string_view & str) { writeFixed(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readFixed(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_int64()); }; + break; + } + + case FieldTypeId::TYPE_FLOAT: + { + write_function = [this](const std::string_view & str) { writeFixed(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readFixed(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_float()); }; + break; + } + + case FieldTypeId::TYPE_DOUBLE: + { + write_function = [this](const std::string_view & str) { writeFixed(parseFromStr(str)); }; + read_function = [this](PaddedPODArray & str) { toStringAppend(readFixed(), str); }; + default_function = [this]() -> String { return toString(field_descriptor.default_value_double()); }; + break; + } + + case FieldTypeId::TYPE_BOOL: + { + write_function = [this](const std::string_view & str) + { + if (str == "true") + writeUInt(1); + else if (str == "false") + writeUInt(0); + else + cannotConvertValue(str, "String", field_descriptor.type_name()); + }; + + read_function = [this](PaddedPODArray & str) + { + UInt64 u64 = readUInt(); + if (u64 < 2) + { + std::string_view ref(u64 ? "true" : "false"); + str.insert(ref.data(), ref.data() + ref.length()); + } + else + cannotConvertValue(toString(u64), field_descriptor.type_name(), "String"); + }; + + default_function = [this]() -> String + { + return field_descriptor.default_value_bool() ? "true" : "false"; + }; + break; + } + + case FieldTypeId::TYPE_STRING: + case FieldTypeId::TYPE_BYTES: + { + write_function = [this](const std::string_view & str) { writeStr(str); }; + read_function = [this](PaddedPODArray & str) { readStrAndAppend(str); }; + default_function = [this]() -> String { return field_descriptor.default_value_string(); }; + break; + } + + case FieldTypeId::TYPE_ENUM: + { + write_function = [this](const std::string_view & str) { writeInt(stringToProtobufEnumValue(str)); }; + read_function = [this](PaddedPODArray & str) { protobufEnumValueToStringAppend(readInt(), str); }; + default_function = [this]() -> String { return field_descriptor.default_value_enum()->name(); }; + break; + } + + default: + failedToSetFunctions(); + } + } + + [[noreturn]] void failedToSetFunctions() + { + throw Exception( + "The field " + quoteString(field_descriptor.full_name()) + " has an incompatible type " + field_descriptor.type_name() + + " for serialization of the data type " + quoteString(is_fixed_string ? "FixedString" : "String"), + ErrorCodes::DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD); + } + + const PaddedPODArray & getDefaultString() + { + if (!default_string) + { + PaddedPODArray arr; + auto str = default_function(); + arr.insert(str.data(), str.data() + str.size()); + if constexpr (is_fixed_string) + ColumnFixedString::alignStringLength(arr, n, 0); + default_string = std::move(arr); + } + return *default_string; + } + + template + void toStringAppend(NumberType value, PaddedPODArray & str) + { + WriteBufferFromVector buf{str, WriteBufferFromVector>::AppendModeTag{}}; + writeText(value, buf); + } + + void prepareEnumMapping() + { + if ((field_typeid == google::protobuf::FieldDescriptor::TYPE_ENUM) && writer) + { + const auto & enum_descriptor = *field_descriptor.enum_type(); + for (int i = 0; i != enum_descriptor.value_count(); ++i) + { + const auto & enum_value_descriptor = *enum_descriptor.value(i); + string_to_protobuf_enum_value_map.emplace(enum_value_descriptor.name(), enum_value_descriptor.number()); + } + } + } + + int stringToProtobufEnumValue(const std::string_view & str) const + { + auto it = string_to_protobuf_enum_value_map.find(str); + if (it == string_to_protobuf_enum_value_map.end()) + cannotConvertValue(str, "String", field_descriptor.type_name()); + return it->second; + } + + std::string_view protobufEnumValueToString(int value) const + { + const auto * enum_value_descriptor = field_descriptor.enum_type()->FindValueByNumber(value); + if (!enum_value_descriptor) + cannotConvertValue(toString(value), field_descriptor.type_name(), "String"); + return enum_value_descriptor->name(); + } + + void protobufEnumValueToStringAppend(int value, PaddedPODArray & str) const + { + auto name = protobufEnumValueToString(value); + str.insert(name.data(), name.data() + name.length()); + } + + size_t n = 0; + std::function write_function; + std::function &)> read_function; + std::function default_function; + std::unordered_map string_to_protobuf_enum_value_map; + PaddedPODArray text_buffer; + std::optional> default_string; + }; + + + /// Serializes ColumnVector containing enum values to a field of any type + /// except TYPE_MESSAGE, TYPE_GROUP, TYPE_FLOAT, TYPE_DOUBLE, TYPE_BOOL. + /// NumberType can be either Int8 or Int16. + template + class ProtobufSerializerEnum : public ProtobufSerializerNumber + { + public: + using ColumnType = ColumnVector; + using EnumDataType = DataTypeEnum; + using BaseClass = ProtobufSerializerNumber; + + ProtobufSerializerEnum( + const std::shared_ptr & enum_data_type_, + const FieldDescriptor & field_descriptor_, + const ProtobufReaderOrWriter & reader_or_writer_) + : BaseClass(field_descriptor_, reader_or_writer_), enum_data_type(enum_data_type_) + { + assert(enum_data_type); + setFunctions(); + prepareEnumMapping(); + } + + private: + void setFunctions() + { + switch (this->field_typeid) + { + case FieldTypeId::TYPE_INT32: + case FieldTypeId::TYPE_SINT32: + case FieldTypeId::TYPE_UINT32: + case FieldTypeId::TYPE_INT64: + case FieldTypeId::TYPE_SINT64: + case FieldTypeId::TYPE_UINT64: + case FieldTypeId::TYPE_FIXED32: + case FieldTypeId::TYPE_SFIXED32: + case FieldTypeId::TYPE_FIXED64: + case FieldTypeId::TYPE_SFIXED64: + { + auto base_read_function = this->read_function; + this->read_function = [this, base_read_function]() -> NumberType + { + NumberType value = base_read_function(); + checkEnumDataTypeValue(value); + return value; + }; + + auto base_default_function = this->default_function; + this->default_function = [this, base_default_function]() -> NumberType + { + auto value = base_default_function(); + checkEnumDataTypeValue(value); + return value; + }; + break; + } + + case FieldTypeId::TYPE_STRING: + case FieldTypeId::TYPE_BYTES: + { + this->write_function = [this](NumberType value) + { + writeStr(enumDataTypeValueToString(value)); + }; + + this->read_function = [this]() -> NumberType + { + readStr(this->text_buffer); + return stringToEnumDataTypeValue(this->text_buffer); + }; + + this->default_function = [this]() -> NumberType + { + return stringToEnumDataTypeValue(this->field_descriptor.default_value_string()); + }; + break; + } + + case FieldTypeId::TYPE_ENUM: + { + this->write_function = [this](NumberType value) { writeInt(enumDataTypeValueToProtobufEnumValue(value)); }; + this->read_function = [this]() -> NumberType { return protobufEnumValueToEnumDataTypeValue(readInt()); }; + this->default_function = [this]() -> NumberType { return protobufEnumValueToEnumDataTypeValue(this->field_descriptor.default_value_enum()->number()); }; + break; + } + + default: + failedToSetFunctions(); + } + } + + [[noreturn]] void failedToSetFunctions() + { + throw Exception( + "The field " + quoteString(this->field_descriptor.full_name()) + " has an incompatible type " + this->field_descriptor.type_name() + + " for serialization of the data type " + quoteString(enum_data_type->getName()), + ErrorCodes::DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD); + } + + void checkEnumDataTypeValue(NumberType value) + { + enum_data_type->findByValue(value); /// Throws an exception if the value isn't defined in the DataTypeEnum. + } + + std::string_view enumDataTypeValueToString(NumberType value) const { return std::string_view{enum_data_type->getNameForValue(value)}; } + NumberType stringToEnumDataTypeValue(const String & str) const { return enum_data_type->getValue(str); } + + void prepareEnumMapping() + { + if (this->field_typeid != FieldTypeId::TYPE_ENUM) + return; + + const auto & enum_descriptor = *this->field_descriptor.enum_type(); + + /// We have two mappings: + /// enum_data_type: "string->NumberType" and protobuf_enum: string->int". + /// And here we want to make from those two mapping a new mapping "NumberType->int" (if we're writing protobuf data), + /// or "int->NumberType" (if we're reading protobuf data). + + auto add_to_mapping = [&](NumberType enum_data_type_value, int protobuf_enum_value) + { + if (this->writer) + enum_data_type_value_to_protobuf_enum_value_map.emplace(enum_data_type_value, protobuf_enum_value); + else + protobuf_enum_value_to_enum_data_type_value_map.emplace(protobuf_enum_value, enum_data_type_value); + }; + + auto iless = [](const std::string_view & s1, const std::string_view & s2) { return ColumnNameWithProtobufFieldNameComparator::less(s1, s2); }; + boost::container::flat_map string_to_protobuf_enum_value_map; + typename decltype(string_to_protobuf_enum_value_map)::sequence_type string_to_protobuf_enum_value_seq; + for (int i : ext::range(enum_descriptor.value_count())) + string_to_protobuf_enum_value_seq.emplace_back(enum_descriptor.value(i)->name(), enum_descriptor.value(i)->number()); + string_to_protobuf_enum_value_map.adopt_sequence(std::move(string_to_protobuf_enum_value_seq)); + + std::vector not_found_by_name_values; + not_found_by_name_values.reserve(enum_data_type->getValues().size()); + + /// Find mapping between enum_data_type and protobuf_enum by name (case insensitively), + /// i.e. we add to the mapping + /// NumberType(enum_data_type) -> "NAME"(enum_data_type) -> + /// -> "NAME"(protobuf_enum, same name) -> int(protobuf_enum) + for (const auto & [name, value] : enum_data_type->getValues()) + { + auto it = string_to_protobuf_enum_value_map.find(name); + if (it != string_to_protobuf_enum_value_map.end()) + add_to_mapping(value, it->second); + else + not_found_by_name_values.push_back(value); + } + + if (!not_found_by_name_values.empty()) + { + /// Find mapping between two enum_data_type and protobuf_enum by value. + /// If the same value has different names in enum_data_type and protobuf_enum + /// we can still add it to our mapping, i.e. we add to the mapping + /// NumberType(enum_data_type) -> int(protobuf_enum, same value) + for (NumberType value : not_found_by_name_values) + { + if (enum_descriptor.FindValueByNumber(value)) + add_to_mapping(value, value); + } + } + + size_t num_mapped_values = this->writer ? enum_data_type_value_to_protobuf_enum_value_map.size() + : protobuf_enum_value_to_enum_data_type_value_map.size(); + + if (!num_mapped_values && !enum_data_type->getValues().empty() && enum_descriptor.value_count()) + { + throw Exception( + "Couldn't find mapping between data type " + enum_data_type->getName() + " and the enum " + quoteString(enum_descriptor.full_name()) + + " in the protobuf schema", + ErrorCodes::DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD); + } + } + + int enumDataTypeValueToProtobufEnumValue(NumberType value) const + { + auto it = enum_data_type_value_to_protobuf_enum_value_map.find(value); + if (it == enum_data_type_value_to_protobuf_enum_value_map.end()) + cannotConvertValue(toString(value), enum_data_type->getName(), this->field_descriptor.type_name()); + return it->second; + } + + NumberType protobufEnumValueToEnumDataTypeValue(int value) const + { + auto it = protobuf_enum_value_to_enum_data_type_value_map.find(value); + if (it == protobuf_enum_value_to_enum_data_type_value_map.end()) + cannotConvertValue(toString(value), this->field_descriptor.type_name(), enum_data_type->getName()); + return it->second; + } + + Int64 readInt() { return ProtobufSerializerSingleValue::readInt(); } + void writeInt(Int64 value) { ProtobufSerializerSingleValue::writeInt(value); } + void writeStr(const std::string_view & str) { ProtobufSerializerSingleValue::writeStr(str); } + void readStr(String & str) { ProtobufSerializerSingleValue::readStr(str); } + [[noreturn]] void cannotConvertValue(const std::string_view & src_value, const std::string_view & src_type_name, const std::string_view & dest_type_name) const { ProtobufSerializerSingleValue::cannotConvertValue(src_value, src_type_name, dest_type_name); } + + const std::shared_ptr enum_data_type; + std::unordered_map enum_data_type_value_to_protobuf_enum_value_map; + std::unordered_map protobuf_enum_value_to_enum_data_type_value_map; + }; + + + /// Serializes a ColumnDecimal to any field except TYPE_MESSAGE, TYPE_GROUP, TYPE_ENUM. + /// DecimalType must be one of the following types: Decimal32, Decimal64, Decimal128, Decimal256, DateTime64. + template + class ProtobufSerializerDecimal : public ProtobufSerializerSingleValue + { + public: + using ColumnType = ColumnDecimal; + + ProtobufSerializerDecimal( + const DataTypeDecimalBase & decimal_data_type_, + const FieldDescriptor & field_descriptor_, + const ProtobufReaderOrWriter & reader_or_writer_) + : ProtobufSerializerSingleValue(field_descriptor_, reader_or_writer_) + , precision(decimal_data_type_.getPrecision()) + , scale(decimal_data_type_.getScale()) + { + setFunctions(); + } + + void writeRow(size_t row_num) override + { + const auto & column_decimal = assert_cast(*column); + write_function(column_decimal.getElement(row_num)); + } + + void readRow(size_t row_num) override + { + DecimalType decimal = read_function(); + auto & column_decimal = assert_cast(column->assumeMutableRef()); + if (row_num < column_decimal.size()) + column_decimal.getElement(row_num) = decimal; + else + column_decimal.insertValue(decimal); + } + + void insertDefaults(size_t row_num) override + { + auto & column_decimal = assert_cast(column->assumeMutableRef()); + if (row_num < column_decimal.size()) + return; + column_decimal.insertValue(getDefaultDecimal()); + } + + private: + void setFunctions() + { + switch (field_typeid) + { + case FieldTypeId::TYPE_INT32: + { + write_function = [this](const DecimalType & decimal) { writeInt(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readInt()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_int32()); }; + break; + } + + case FieldTypeId::TYPE_SINT32: + { + write_function = [this](const DecimalType & decimal) { writeSInt(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readSInt()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_int32()); }; + break; + } + + case FieldTypeId::TYPE_UINT32: + { + write_function = [this](const DecimalType & decimal) { writeUInt(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readUInt()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_uint32()); }; + break; + } + + case FieldTypeId::TYPE_INT64: + { + write_function = [this](const DecimalType & decimal) { writeInt(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readInt()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_int64()); }; + break; + } + + case FieldTypeId::TYPE_SINT64: + { + write_function = [this](const DecimalType & decimal) { writeSInt(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readSInt()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_int64()); }; + break; + } + + case FieldTypeId::TYPE_UINT64: + { + write_function = [this](const DecimalType & decimal) { writeUInt(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readUInt()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_uint64()); }; + break; + } + + case FieldTypeId::TYPE_FIXED32: + { + write_function = [this](const DecimalType & decimal) { writeFixed(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readFixed()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_uint32()); }; + break; + } + + case FieldTypeId::TYPE_SFIXED32: + { + write_function = [this](const DecimalType & decimal) { writeFixed(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readFixed()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_int32()); }; + break; + } + + case FieldTypeId::TYPE_FIXED64: + { + write_function = [this](const DecimalType & decimal) { writeFixed(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readFixed()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_uint64()); }; + break; + } + + case FieldTypeId::TYPE_SFIXED64: + { + write_function = [this](const DecimalType & decimal) { writeFixed(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readFixed()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_int64()); }; + break; + } + + case FieldTypeId::TYPE_FLOAT: + { + write_function = [this](const DecimalType & decimal) { writeFixed(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readFixed()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_float()); }; + break; + } + + case FieldTypeId::TYPE_DOUBLE: + { + write_function = [this](const DecimalType & decimal) { writeFixed(decimalToNumber(decimal)); }; + read_function = [this]() -> DecimalType { return numberToDecimal(readFixed()); }; + default_function = [this]() -> DecimalType { return numberToDecimal(field_descriptor.default_value_double()); }; + break; + } + + case FieldTypeId::TYPE_BOOL: + { + if (std::is_same_v) + failedToSetFunctions(); + else + { + write_function = [this](const DecimalType & decimal) + { + if (decimal.value == 0) + writeInt(0); + else if (DecimalComparison::compare(decimal, 1, scale, 0)) + writeInt(1); + else + { + WriteBufferFromOwnString buf; + writeText(decimal, scale, buf); + cannotConvertValue(buf.str(), TypeName::get(), field_descriptor.type_name()); + } + }; + + read_function = [this]() -> DecimalType + { + UInt64 u64 = readUInt(); + if (u64 < 2) + return numberToDecimal(static_cast(u64 != 0)); + else + cannotConvertValue(toString(u64), field_descriptor.type_name(), TypeName::get()); + }; + + default_function = [this]() -> DecimalType + { + return numberToDecimal(static_cast(field_descriptor.default_value_bool())); + }; + } + break; + } + + case FieldTypeId::TYPE_STRING: + case FieldTypeId::TYPE_BYTES: + { + write_function = [this](const DecimalType & decimal) + { + decimalToString(decimal, text_buffer); + writeStr(text_buffer); + }; + + read_function = [this]() -> DecimalType + { + readStr(text_buffer); + return stringToDecimal(text_buffer); + }; + + default_function = [this]() -> DecimalType { return stringToDecimal(field_descriptor.default_value_string()); }; + break; + } + + default: + failedToSetFunctions(); + } + } + + [[noreturn]] void failedToSetFunctions() + { + throw Exception( + "The field " + quoteString(field_descriptor.full_name()) + " has an incompatible type " + field_descriptor.type_name() + + " for serialization of the data type " + quoteString(TypeName::get()), + ErrorCodes::DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD); + } + + DecimalType getDefaultDecimal() + { + if (!default_decimal) + default_decimal = default_function(); + return *default_decimal; + } + + template + DecimalType numberToDecimal(NumberType value) const + { + return convertToDecimal, DataTypeDecimal>(value, scale); + } + + template + NumberType decimalToNumber(const DecimalType & decimal) const + { + return DecimalUtils::convertTo(decimal, scale); + } + + void decimalToString(const DecimalType & decimal, String & str) const + { + WriteBufferFromString buf{str}; + if constexpr (std::is_same_v) + writeDateTimeText(decimal, scale, buf); + else + writeText(decimal, scale, buf); + } + + DecimalType stringToDecimal(const String & str) const + { + ReadBufferFromString buf(str); + DecimalType decimal{0}; + if constexpr (std::is_same_v) + readDateTime64Text(decimal, scale, buf); + else + DataTypeDecimal::readText(decimal, buf, precision, scale); + return decimal; + } + + const UInt32 precision; + const UInt32 scale; + std::function write_function; + std::function read_function; + std::function default_function; + std::optional default_decimal; + String text_buffer; + }; + + using ProtobufSerializerDateTime64 = ProtobufSerializerDecimal; + + + /// Serializes a ColumnVector containing dates to a field of any type except TYPE_MESSAGE, TYPE_GROUP, TYPE_BOOL, TYPE_ENUM. + class ProtobufSerializerDate : public ProtobufSerializerNumber + { + public: + ProtobufSerializerDate( + const FieldDescriptor & field_descriptor_, + const ProtobufReaderOrWriter & reader_or_writer_) + : ProtobufSerializerNumber(field_descriptor_, reader_or_writer_) + { + setFunctions(); + } + + private: + void setFunctions() + { + switch (field_typeid) + { + case FieldTypeId::TYPE_INT32: + case FieldTypeId::TYPE_SINT32: + case FieldTypeId::TYPE_UINT32: + case FieldTypeId::TYPE_INT64: + case FieldTypeId::TYPE_SINT64: + case FieldTypeId::TYPE_UINT64: + case FieldTypeId::TYPE_FIXED32: + case FieldTypeId::TYPE_SFIXED32: + case FieldTypeId::TYPE_FIXED64: + case FieldTypeId::TYPE_SFIXED64: + case FieldTypeId::TYPE_FLOAT: + case FieldTypeId::TYPE_DOUBLE: + break; /// already set in ProtobufSerializerNumber::setFunctions(). + + case FieldTypeId::TYPE_STRING: + case FieldTypeId::TYPE_BYTES: + { + write_function = [this](UInt16 value) + { + dateToString(static_cast(value), text_buffer); + writeStr(text_buffer); + }; + + read_function = [this]() -> UInt16 + { + readStr(text_buffer); + return stringToDate(text_buffer); + }; + + default_function = [this]() -> UInt16 { return stringToDate(field_descriptor.default_value_string()); }; + break; + } + + default: + failedToSetFunctions(); + } + } + + static void dateToString(DayNum date, String & str) + { + WriteBufferFromString buf{str}; + writeText(date, buf); + } + + static DayNum stringToDate(const String & str) + { + DayNum date; + ReadBufferFromString buf{str}; + readDateText(date, buf); + return date; + } + + [[noreturn]] void failedToSetFunctions() + { + throw Exception( + "The field " + quoteString(field_descriptor.full_name()) + " has an incompatible type " + field_descriptor.type_name() + + " for serialization of the data type 'Date'", + ErrorCodes::DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD); + } + }; + + + /// Serializes a ColumnVector containing dates to a field of any type except TYPE_MESSAGE, TYPE_GROUP, TYPE_BOOL, TYPE_ENUM. + class ProtobufSerializerDateTime : public ProtobufSerializerNumber + { + public: + ProtobufSerializerDateTime( + const FieldDescriptor & field_descriptor_, const ProtobufReaderOrWriter & reader_or_writer_) + : ProtobufSerializerNumber(field_descriptor_, reader_or_writer_) + { + setFunctions(); + } + + protected: + void setFunctions() + { + switch (field_typeid) + { + case FieldTypeId::TYPE_INT32: + case FieldTypeId::TYPE_SINT32: + case FieldTypeId::TYPE_UINT32: + case FieldTypeId::TYPE_INT64: + case FieldTypeId::TYPE_SINT64: + case FieldTypeId::TYPE_UINT64: + case FieldTypeId::TYPE_FIXED32: + case FieldTypeId::TYPE_SFIXED32: + case FieldTypeId::TYPE_FIXED64: + case FieldTypeId::TYPE_SFIXED64: + case FieldTypeId::TYPE_FLOAT: + case FieldTypeId::TYPE_DOUBLE: + break; /// already set in ProtobufSerializerNumber::setFunctions(). + + case FieldTypeId::TYPE_STRING: + case FieldTypeId::TYPE_BYTES: + { + write_function = [this](UInt32 value) + { + dateTimeToString(value, text_buffer); + writeStr(text_buffer); + }; + + read_function = [this]() -> UInt32 + { + readStr(text_buffer); + return stringToDateTime(text_buffer); + }; + + default_function = [this]() -> UInt32 { return stringToDateTime(field_descriptor.default_value_string()); }; + break; + } + + default: + failedToSetFunctions(); + } + } + + static void dateTimeToString(time_t tm, String & str) + { + WriteBufferFromString buf{str}; + writeDateTimeText(tm, buf); + } + + static time_t stringToDateTime(const String & str) + { + ReadBufferFromString buf{str}; + time_t tm = 0; + readDateTimeText(tm, buf); + return tm; + } + + [[noreturn]] void failedToSetFunctions() + { + throw Exception( + "The field " + quoteString(field_descriptor.full_name()) + " has an incompatible type " + field_descriptor.type_name() + + " for serialization of the data type 'DateTime'", + ErrorCodes::DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD); + } + }; + + + /// Serializes a ColumnVector containing UUIDs to a field of type TYPE_STRING or TYPE_BYTES. + class ProtobufSerializerUUID : public ProtobufSerializerNumber + { + public: + ProtobufSerializerUUID( + const google::protobuf::FieldDescriptor & field_descriptor_, + const ProtobufReaderOrWriter & reader_or_writer_) + : ProtobufSerializerNumber(field_descriptor_, reader_or_writer_) + { + setFunctions(); + } + + private: + void setFunctions() + { + if ((field_typeid != FieldTypeId::TYPE_STRING) && (field_typeid != FieldTypeId::TYPE_BYTES)) + { + throw Exception( + "The field " + quoteString(field_descriptor.full_name()) + " has an incompatible type " + field_descriptor.type_name() + + " for serialization of the data type UUID", + ErrorCodes::DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD); + } + + write_function = [this](UInt128 value) + { + uuidToString(static_cast(value), text_buffer); + writeStr(text_buffer); + }; + + read_function = [this]() -> UInt128 + { + readStr(text_buffer); + return stringToUUID(text_buffer); + }; + + default_function = [this]() -> UInt128 { return stringToUUID(field_descriptor.default_value_string()); }; + } + + static void uuidToString(const UUID & uuid, String & str) + { + WriteBufferFromString buf{str}; + writeText(uuid, buf); + } + + static UUID stringToUUID(const String & str) + { + ReadBufferFromString buf{str}; + UUID uuid; + readUUIDText(uuid, buf); + return uuid; + } + }; + + + using ProtobufSerializerInterval = ProtobufSerializerNumber; + + + /// Serializes a ColumnAggregateFunction to a field of type TYPE_STRING or TYPE_BYTES. + class ProtobufSerializerAggregateFunction : public ProtobufSerializerSingleValue + { + public: + ProtobufSerializerAggregateFunction( + const std::shared_ptr & aggregate_function_data_type_, + const google::protobuf::FieldDescriptor & field_descriptor_, + const ProtobufReaderOrWriter & reader_or_writer_) + : ProtobufSerializerSingleValue(field_descriptor_, reader_or_writer_) + , aggregate_function_data_type(aggregate_function_data_type_) + , aggregate_function(aggregate_function_data_type->getFunction()) + { + if ((field_typeid != FieldTypeId::TYPE_STRING) && (field_typeid != FieldTypeId::TYPE_BYTES)) + { + throw Exception( + "The field " + quoteString(field_descriptor.full_name()) + " has an incompatible type " + field_descriptor.type_name() + + " for serialization of the data type " + quoteString(aggregate_function_data_type->getName()), + ErrorCodes::DATA_TYPE_INCOMPATIBLE_WITH_PROTOBUF_FIELD); + } + } + + void writeRow(size_t row_num) override + { + const auto & column_af = assert_cast(*column); + dataToString(column_af.getData()[row_num], text_buffer); + writeStr(text_buffer); + } + + void readRow(size_t row_num) override + { + auto & column_af = assert_cast(column->assumeMutableRef()); + Arena & arena = column_af.createOrGetArena(); + AggregateDataPtr data; + readStr(text_buffer); + data = stringToData(text_buffer, arena); + + if (row_num < column_af.size()) + { + auto * old_data = std::exchange(column_af.getData()[row_num], data); + aggregate_function->destroy(old_data); + } + else + column_af.getData().push_back(data); + } + + void insertDefaults(size_t row_num) override + { + auto & column_af = assert_cast(column->assumeMutableRef()); + if (row_num < column_af.size()) + return; + + Arena & arena = column_af.createOrGetArena(); + AggregateDataPtr data = stringToData(field_descriptor.default_value_string(), arena); + column_af.getData().push_back(data); + } + + private: + void dataToString(ConstAggregateDataPtr data, String & str) const + { + WriteBufferFromString buf{str}; + aggregate_function->serialize(data, buf); + } + + AggregateDataPtr stringToData(const String & str, Arena & arena) const + { + size_t size_of_state = aggregate_function->sizeOfData(); + AggregateDataPtr data = arena.alignedAlloc(size_of_state, aggregate_function->alignOfData()); + try + { + aggregate_function->create(data); + ReadBufferFromMemory buf(str.data(), str.length()); + aggregate_function->deserialize(data, buf, &arena); + return data; + } + catch (...) + { + aggregate_function->destroy(data); + throw; + } + } + + const std::shared_ptr aggregate_function_data_type; + const AggregateFunctionPtr aggregate_function; + String text_buffer; + }; + + + /// Serializes a ColumnNullable. + class ProtobufSerializerNullable : public ProtobufSerializer + { + public: + explicit ProtobufSerializerNullable(std::unique_ptr nested_serializer_) + : nested_serializer(std::move(nested_serializer_)) + { + } + + void setColumns(const ColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + column = columns[0]; + const auto & column_nullable = assert_cast(*column); + ColumnPtr nested_column = column_nullable.getNestedColumnPtr(); + nested_serializer->setColumns(&nested_column, 1); + } + + void setColumns(const MutableColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + ColumnPtr column0 = columns[0]->getPtr(); + setColumns(&column0, 1); + } + + void writeRow(size_t row_num) override + { + const auto & column_nullable = assert_cast(*column); + const auto & null_map = column_nullable.getNullMapData(); + if (!null_map[row_num]) + nested_serializer->writeRow(row_num); + } + + void readRow(size_t row_num) override + { + auto & column_nullable = assert_cast(column->assumeMutableRef()); + auto & nested_column = column_nullable.getNestedColumn(); + auto & null_map = column_nullable.getNullMapData(); + size_t old_size = null_map.size(); + + nested_serializer->readRow(row_num); + + if (row_num < old_size) + { + null_map[row_num] = false; + } + else + { + size_t new_size = nested_column.size(); + if (new_size != old_size + 1) + throw Exception("Size of ColumnNullable is unexpected", ErrorCodes::LOGICAL_ERROR); + try + { + null_map.push_back(false); + } + catch (...) + { + nested_column.popBack(1); + throw; + } + } + } + + void insertDefaults(size_t row_num) override + { + auto & column_nullable = assert_cast(column->assumeMutableRef()); + if (row_num < column_nullable.size()) + return; + column_nullable.insertDefault(); + } + + private: + const std::unique_ptr nested_serializer; + ColumnPtr column; + }; + + + /// Serializes a ColumnMap. + class ProtobufSerializerMap : public ProtobufSerializer + { + public: + explicit ProtobufSerializerMap(std::unique_ptr nested_serializer_) + : nested_serializer(std::move(nested_serializer_)) + { + } + + void setColumns(const ColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + const auto & column_map = assert_cast(*columns[0]); + ColumnPtr nested_column = column_map.getNestedColumnPtr(); + nested_serializer->setColumns(&nested_column, 1); + } + + void setColumns(const MutableColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + ColumnPtr column0 = columns[0]->getPtr(); + setColumns(&column0, 1); + } + + void writeRow(size_t row_num) override { nested_serializer->writeRow(row_num); } + void readRow(size_t row_num) override { nested_serializer->readRow(row_num); } + void insertDefaults(size_t row_num) override { nested_serializer->insertDefaults(row_num); } + + private: + const std::unique_ptr nested_serializer; + }; + + + /// Serializes a ColumnLowCardinality. + class ProtobufSerializerLowCardinality : public ProtobufSerializer + { + public: + explicit ProtobufSerializerLowCardinality(std::unique_ptr nested_serializer_) + : nested_serializer(std::move(nested_serializer_)) + { + } + + void setColumns(const ColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + column = columns[0]; + const auto & column_lc = assert_cast(*column); + ColumnPtr nested_column = column_lc.getDictionary().getNestedColumn(); + nested_serializer->setColumns(&nested_column, 1); + read_value_column_set = false; + } + + void setColumns(const MutableColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + ColumnPtr column0 = columns[0]->getPtr(); + setColumns(&column0, 1); + } + + void writeRow(size_t row_num) override + { + const auto & column_lc = assert_cast(*column); + size_t unique_row_number = column_lc.getIndexes().getUInt(row_num); + nested_serializer->writeRow(unique_row_number); + } + + void readRow(size_t row_num) override + { + auto & column_lc = assert_cast(column->assumeMutableRef()); + + if (!read_value_column_set) + { + if (!read_value_column) + { + ColumnPtr nested_column = column_lc.getDictionary().getNestedColumn(); + read_value_column = nested_column->cloneEmpty(); + } + nested_serializer->setColumns(&read_value_column, 1); + read_value_column_set = true; + } + + read_value_column->popBack(read_value_column->size()); + nested_serializer->readRow(0); + + if (row_num < column_lc.size()) + { + if (row_num != column_lc.size() - 1) + throw Exception("Cannot replace an element in the middle of ColumnLowCardinality", ErrorCodes::LOGICAL_ERROR); + column_lc.popBack(1); + } + + column_lc.insertFromFullColumn(*read_value_column, 0); + } + + void insertDefaults(size_t row_num) override + { + auto & column_lc = assert_cast(column->assumeMutableRef()); + if (row_num < column_lc.size()) + return; + + if (!default_value_column) + { + ColumnPtr nested_column = column_lc.getDictionary().getNestedColumn(); + default_value_column = nested_column->cloneEmpty(); + nested_serializer->setColumns(&default_value_column, 1); + nested_serializer->insertDefaults(0); + read_value_column_set = false; + } + + column_lc.insertFromFullColumn(*default_value_column, 0); + } + + private: + const std::unique_ptr nested_serializer; + ColumnPtr column; + MutableColumnPtr read_value_column; + bool read_value_column_set = false; + MutableColumnPtr default_value_column; + }; + + + /// Serializes a ColumnArray to a repeated field. + class ProtobufSerializerArray : public ProtobufSerializer + { + public: + explicit ProtobufSerializerArray(std::unique_ptr element_serializer_) + : element_serializer(std::move(element_serializer_)) + { + } + + void setColumns(const ColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + column = columns[0]; + const auto & column_array = assert_cast(*column); + ColumnPtr data_column = column_array.getDataPtr(); + element_serializer->setColumns(&data_column, 1); + } + + void setColumns(const MutableColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + ColumnPtr column0 = columns[0]->getPtr(); + setColumns(&column0, 1); + } + + void writeRow(size_t row_num) override + { + const auto & column_array = assert_cast(*column); + const auto & offsets = column_array.getOffsets(); + for (size_t i : ext::range(offsets[row_num - 1], offsets[row_num])) + element_serializer->writeRow(i); + } + + void readRow(size_t row_num) override + { + auto & column_array = assert_cast(column->assumeMutableRef()); + auto & offsets = column_array.getOffsets(); + size_t old_size = offsets.size(); + if (row_num + 1 < old_size) + throw Exception("Cannot replace an element in the middle of ColumnArray", ErrorCodes::LOGICAL_ERROR); + auto data_column = column_array.getDataPtr(); + size_t old_data_size = data_column->size(); + + try + { + element_serializer->readRow(old_data_size); + size_t data_size = data_column->size(); + if (data_size != old_data_size + 1) + throw Exception("Size of ColumnArray is unexpected", ErrorCodes::LOGICAL_ERROR); + + if (row_num < old_size) + offsets.back() = data_size; + else + offsets.push_back(data_size); + } + catch (...) + { + if (data_column->size() > old_data_size) + data_column->assumeMutableRef().popBack(data_column->size() - old_data_size); + if (offsets.size() > old_size) + column_array.getOffsetsColumn().popBack(offsets.size() - old_size); + throw; + } + } + + void insertDefaults(size_t row_num) override + { + auto & column_array = assert_cast(column->assumeMutableRef()); + if (row_num < column_array.size()) + return; + column_array.insertDefault(); + } + + private: + const std::unique_ptr element_serializer; + ColumnPtr column; + }; + + + /// Serializes a ColumnTuple as a repeated field (just like we serialize arrays). + class ProtobufSerializerTupleAsArray : public ProtobufSerializer + { + public: + ProtobufSerializerTupleAsArray( + const std::shared_ptr & tuple_data_type_, + const FieldDescriptor & field_descriptor_, + std::vector> element_serializers_) + : tuple_data_type(tuple_data_type_) + , tuple_size(tuple_data_type->getElements().size()) + , field_descriptor(field_descriptor_) + , element_serializers(std::move(element_serializers_)) + { + assert(tuple_size); + assert(tuple_size == element_serializers.size()); + } + + void setColumns(const ColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + column = columns[0]; + const auto & column_tuple = assert_cast(*column); + for (size_t i : ext::range(tuple_size)) + { + auto element_column = column_tuple.getColumnPtr(i); + element_serializers[i]->setColumns(&element_column, 1); + } + current_element_index = 0; + } + + void setColumns(const MutableColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + ColumnPtr column0 = columns[0]->getPtr(); + setColumns(&column0, 1); + } + + void writeRow(size_t row_num) override + { + for (size_t i : ext::range(tuple_size)) + element_serializers[i]->writeRow(row_num); + } + + void readRow(size_t row_num) override + { + auto & column_tuple = assert_cast(column->assumeMutableRef()); + + size_t old_size = column_tuple.size(); + if (row_num >= old_size) + current_element_index = 0; + + insertDefaults(row_num); + + if (current_element_index >= tuple_size) + { + throw Exception( + "Too many (" + std::to_string(current_element_index) + ") elements was read from the field " + + field_descriptor.full_name() + " to fit in the data type " + tuple_data_type->getName(), + ErrorCodes::PROTOBUF_BAD_CAST); + } + + element_serializers[current_element_index]->readRow(row_num); + ++current_element_index; + } + + void insertDefaults(size_t row_num) override + { + auto & column_tuple = assert_cast(column->assumeMutableRef()); + size_t old_size = column_tuple.size(); + + if (row_num > old_size) + return; + + try + { + for (size_t i : ext::range(tuple_size)) + element_serializers[i]->insertDefaults(row_num); + } + catch (...) + { + for (size_t i : ext::range(tuple_size)) + { + auto element_column = column_tuple.getColumnPtr(i)->assumeMutable(); + if (element_column->size() > old_size) + element_column->popBack(element_column->size() - old_size); + } + throw; + } + } + + private: + const std::shared_ptr tuple_data_type; + const size_t tuple_size; + const FieldDescriptor & field_descriptor; + const std::vector> element_serializers; + ColumnPtr column; + size_t current_element_index = 0; + }; + + + /// Serializes a message (root or nested) in the protobuf schema. + class ProtobufSerializerMessage : public ProtobufSerializer + { + public: + struct FieldDesc + { + size_t column_index; + size_t num_columns; + const FieldDescriptor * field_descriptor; + std::unique_ptr field_serializer; + }; + + ProtobufSerializerMessage( + std::vector field_descs_, + const FieldDescriptor * parent_field_descriptor_, + bool with_length_delimiter_, + const ProtobufReaderOrWriter & reader_or_writer_) + : parent_field_descriptor(parent_field_descriptor_) + , with_length_delimiter(with_length_delimiter_) + , should_skip_if_empty(parent_field_descriptor ? shouldSkipZeroOrEmpty(*parent_field_descriptor) : false) + , reader(reader_or_writer_.reader) + , writer(reader_or_writer_.writer) + { + field_infos.reserve(field_descs_.size()); + for (auto & desc : field_descs_) + field_infos.emplace_back(desc.column_index, desc.num_columns, *desc.field_descriptor, std::move(desc.field_serializer)); + + std::sort(field_infos.begin(), field_infos.end(), + [](const FieldInfo & lhs, const FieldInfo & rhs) { return lhs.field_tag < rhs.field_tag; }); + + for (size_t i : ext::range(field_infos.size())) + field_index_by_field_tag.emplace(field_infos[i].field_tag, i); + } + + void setColumns(const ColumnPtr * columns_, size_t num_columns_) override + { + columns.assign(columns_, columns_ + num_columns_); + + for (const FieldInfo & info : field_infos) + info.field_serializer->setColumns(columns.data() + info.column_index, info.num_columns); + + if (reader) + { + missing_column_indices.clear(); + missing_column_indices.reserve(num_columns_); + size_t current_idx = 0; + for (const FieldInfo & info : field_infos) + { + while (current_idx < info.column_index) + missing_column_indices.push_back(current_idx++); + current_idx = info.column_index + info.num_columns; + } + while (current_idx < num_columns_) + missing_column_indices.push_back(current_idx++); + } + } + + void setColumns(const MutableColumnPtr * columns_, size_t num_columns_) override + { + Columns cols; + cols.reserve(num_columns_); + for (size_t i : ext::range(num_columns_)) + cols.push_back(columns_[i]->getPtr()); + setColumns(cols.data(), cols.size()); + } + + void writeRow(size_t row_num) override + { + if (parent_field_descriptor) + writer->startNestedMessage(); + else + writer->startMessage(); + + for (const FieldInfo & info : field_infos) + { + if (info.should_pack_repeated) + writer->startRepeatedPack(); + info.field_serializer->writeRow(row_num); + if (info.should_pack_repeated) + writer->endRepeatedPack(info.field_tag, true); + } + + if (parent_field_descriptor) + { + bool is_group = (parent_field_descriptor->type() == FieldTypeId::TYPE_GROUP); + writer->endNestedMessage(parent_field_descriptor->number(), is_group, should_skip_if_empty); + } + else + writer->endMessage(with_length_delimiter); + } + + void readRow(size_t row_num) override + { + if (parent_field_descriptor) + reader->startNestedMessage(); + else + reader->startMessage(with_length_delimiter); + + if (!field_infos.empty()) + { + last_field_index = 0; + last_field_tag = field_infos[0].field_tag; + size_t old_size = columns.empty() ? 0 : columns[0]->size(); + + try + { + int field_tag; + while (reader->readFieldNumber(field_tag)) + { + size_t field_index = findFieldIndexByFieldTag(field_tag); + if (field_index == static_cast(-1)) + continue; + auto * field_serializer = field_infos[field_index].field_serializer.get(); + field_serializer->readRow(row_num); + field_infos[field_index].field_read = true; + } + + for (auto & info : field_infos) + { + if (info.field_read) + info.field_read = false; + else + info.field_serializer->insertDefaults(row_num); + } + } + catch (...) + { + for (auto & column : columns) + { + if (column->size() > old_size) + column->assumeMutableRef().popBack(column->size() - old_size); + } + throw; + } + } + + if (parent_field_descriptor) + reader->endNestedMessage(); + else + reader->endMessage(false); + addDefaultsToMissingColumns(row_num); + } + + void insertDefaults(size_t row_num) override + { + for (const FieldInfo & info : field_infos) + info.field_serializer->insertDefaults(row_num); + addDefaultsToMissingColumns(row_num); + } + + private: + size_t findFieldIndexByFieldTag(int field_tag) + { + while (true) + { + if (field_tag == last_field_tag) + return last_field_index; + if (field_tag < last_field_tag) + break; + if (++last_field_index >= field_infos.size()) + break; + last_field_tag = field_infos[last_field_index].field_tag; + } + last_field_tag = field_tag; + auto it = field_index_by_field_tag.find(field_tag); + if (it == field_index_by_field_tag.end()) + last_field_index = static_cast(-1); + else + last_field_index = it->second; + return last_field_index; + } + + void addDefaultsToMissingColumns(size_t row_num) + { + for (size_t column_idx : missing_column_indices) + { + auto & column = columns[column_idx]; + size_t old_size = column->size(); + if (row_num >= old_size) + column->assumeMutableRef().insertDefault(); + } + } + + struct FieldInfo + { + FieldInfo( + size_t column_index_, + size_t num_columns_, + const FieldDescriptor & field_descriptor_, + std::unique_ptr field_serializer_) + : column_index(column_index_) + , num_columns(num_columns_) + , field_descriptor(&field_descriptor_) + , field_tag(field_descriptor_.number()) + , should_pack_repeated(shouldPackRepeated(field_descriptor_)) + , field_serializer(std::move(field_serializer_)) + { + } + size_t column_index; + size_t num_columns; + const FieldDescriptor * field_descriptor; + int field_tag; + bool should_pack_repeated; + std::unique_ptr field_serializer; + bool field_read = false; + }; + + const FieldDescriptor * const parent_field_descriptor; + const bool with_length_delimiter; + const bool should_skip_if_empty; + ProtobufReader * const reader; + ProtobufWriter * const writer; + std::vector field_infos; + std::unordered_map field_index_by_field_tag; + Columns columns; + std::vector missing_column_indices; + int last_field_tag = 0; + size_t last_field_index = static_cast(-1); + }; + + + /// Serializes a tuple with explicit names as a nested message. + class ProtobufSerializerTupleAsNestedMessage : public ProtobufSerializer + { + public: + explicit ProtobufSerializerTupleAsNestedMessage(std::unique_ptr nested_message_serializer_) + : nested_message_serializer(std::move(nested_message_serializer_)) + { + } + + void setColumns(const ColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + const auto & column_tuple = assert_cast(*columns[0]); + size_t tuple_size = column_tuple.tupleSize(); + assert(tuple_size); + Columns element_columns; + element_columns.reserve(tuple_size); + for (size_t i : ext::range(tuple_size)) + element_columns.emplace_back(column_tuple.getColumnPtr(i)); + nested_message_serializer->setColumns(element_columns.data(), element_columns.size()); + } + + void setColumns(const MutableColumnPtr * columns, [[maybe_unused]] size_t num_columns) override + { + assert(num_columns == 1); + ColumnPtr column0 = columns[0]->getPtr(); + setColumns(&column0, 1); + } + + void writeRow(size_t row_num) override { nested_message_serializer->writeRow(row_num); } + void readRow(size_t row_num) override { nested_message_serializer->readRow(row_num); } + void insertDefaults(size_t row_num) override { nested_message_serializer->insertDefaults(row_num); } + + private: + const std::unique_ptr nested_message_serializer; + }; + + + /// Serializes a flattened Nested data type (an array of tuples with explicit names) + /// as a repeated nested message. + class ProtobufSerializerFlattenedNestedAsArrayOfNestedMessages : public ProtobufSerializer + { + public: + explicit ProtobufSerializerFlattenedNestedAsArrayOfNestedMessages( + std::unique_ptr nested_message_serializer_) + : nested_message_serializer(std::move(nested_message_serializer_)) + { + } + + void setColumns(const ColumnPtr * columns, size_t num_columns) override + { + assert(num_columns); + data_columns.clear(); + data_columns.reserve(num_columns); + offset_columns.clear(); + offset_columns.reserve(num_columns); + + for (size_t i : ext::range(num_columns)) + { + const auto & column_array = assert_cast(*columns[i]); + data_columns.emplace_back(column_array.getDataPtr()); + offset_columns.emplace_back(column_array.getOffsetsPtr()); + } + + std::sort(offset_columns.begin(), offset_columns.end()); + offset_columns.erase(std::unique(offset_columns.begin(), offset_columns.end()), offset_columns.end()); + + nested_message_serializer->setColumns(data_columns.data(), data_columns.size()); + } + + void setColumns(const MutableColumnPtr * columns, size_t num_columns) override + { + Columns cols; + cols.reserve(num_columns); + for (size_t i : ext::range(num_columns)) + cols.push_back(columns[i]->getPtr()); + setColumns(cols.data(), cols.size()); + } + + void writeRow(size_t row_num) override + { + const auto & offset_column0 = assert_cast(*offset_columns[0]); + size_t start_offset = offset_column0.getElement(row_num - 1); + size_t end_offset = offset_column0.getElement(row_num); + for (size_t i : ext::range(1, offset_columns.size())) + { + const auto & offset_column = assert_cast(*offset_columns[i]); + if (offset_column.getElement(row_num) != end_offset) + throw Exception("Components of FlattenedNested have different sizes", ErrorCodes::PROTOBUF_BAD_CAST); + } + for (size_t i : ext::range(start_offset, end_offset)) + nested_message_serializer->writeRow(i); + } + + void readRow(size_t row_num) override + { + size_t old_size = offset_columns[0]->size(); + if (row_num + 1 < old_size) + throw Exception("Cannot replace an element in the middle of ColumnArray", ErrorCodes::LOGICAL_ERROR); + + size_t old_data_size = data_columns[0]->size(); + + try + { + nested_message_serializer->readRow(old_data_size); + size_t data_size = data_columns[0]->size(); + if (data_size != old_data_size + 1) + throw Exception("Unexpected number of elements of ColumnArray has been read", ErrorCodes::LOGICAL_ERROR); + + if (row_num < old_size) + { + for (auto & offset_column : offset_columns) + assert_cast(offset_column->assumeMutableRef()).getData().back() = data_size; + } + else + { + for (auto & offset_column : offset_columns) + assert_cast(offset_column->assumeMutableRef()).getData().push_back(data_size); + } + } + catch (...) + { + for (auto & data_column : data_columns) + { + if (data_column->size() > old_data_size) + data_column->assumeMutableRef().popBack(data_column->size() - old_data_size); + } + for (auto & offset_column : offset_columns) + { + if (offset_column->size() > old_size) + offset_column->assumeMutableRef().popBack(offset_column->size() - old_size); + } + throw; + } + } + + void insertDefaults(size_t row_num) override + { + size_t old_size = offset_columns[0]->size(); + if (row_num < old_size) + return; + + try + { + size_t data_size = data_columns[0]->size(); + for (auto & offset_column : offset_columns) + assert_cast(offset_column->assumeMutableRef()).getData().push_back(data_size); + } + catch (...) + { + for (auto & offset_column : offset_columns) + { + if (offset_column->size() > old_size) + offset_column->assumeMutableRef().popBack(offset_column->size() - old_size); + } + throw; + } + } + + private: + const std::unique_ptr nested_message_serializer; + Columns data_columns; + Columns offset_columns; + }; + + + /// Produces a tree of ProtobufSerializers which serializes a row as a protobuf message. + class ProtobufSerializerBuilder + { + public: + explicit ProtobufSerializerBuilder(const ProtobufReaderOrWriter & reader_or_writer_) : reader_or_writer(reader_or_writer_) {} + + std::unique_ptr buildMessageSerializer( + const Strings & column_names, + const DataTypes & data_types, + std::vector & missing_column_indices, + const MessageDescriptor & message_descriptor, + bool with_length_delimiter) + { + std::vector used_column_indices; + auto serializer = buildMessageSerializerImpl( + /* num_columns = */ column_names.size(), + column_names.data(), + data_types.data(), + used_column_indices, + message_descriptor, + with_length_delimiter, + /* parent_field_descriptor = */ nullptr); + + if (!serializer) + { + throw Exception( + "Not found matches between the names of the columns {" + boost::algorithm::join(column_names, ", ") + + "} and the fields {" + boost::algorithm::join(getFieldNames(message_descriptor), ", ") + "} of the message " + + quoteString(message_descriptor.full_name()) + " in the protobuf schema", + ErrorCodes::NO_COLUMNS_SERIALIZED_TO_PROTOBUF_FIELDS); + } + + missing_column_indices.clear(); + missing_column_indices.reserve(column_names.size() - used_column_indices.size()); + boost::range::set_difference(ext::range(column_names.size()), used_column_indices, + std::back_inserter(missing_column_indices)); + + return serializer; + } + + private: + /// Collects all field names from the message (used only to format error messages). + static Strings getFieldNames(const MessageDescriptor & message_descriptor) + { + Strings field_names; + field_names.reserve(message_descriptor.field_count()); + for (int i : ext::range(message_descriptor.field_count())) + field_names.emplace_back(message_descriptor.field(i)->name()); + return field_names; + } + + static bool columnNameEqualsToFieldName(const std::string_view & column_name, const FieldDescriptor & field_descriptor) + { + std::string_view suffix; + return columnNameStartsWithFieldName(column_name, field_descriptor, suffix) && suffix.empty(); + } + + /// Checks if a passed column's name starts with a specified field's name. + /// The function also assigns `suffix` to the rest part of the column's name + /// which doesn't match to the field's name. + /// The function requires that rest part of the column's name to be started with a dot '.' or underline '_', + /// but doesn't include those '.' or '_' characters into `suffix`. + static bool columnNameStartsWithFieldName(const std::string_view & column_name, const FieldDescriptor & field_descriptor, std::string_view & suffix) + { + size_t matching_length = 0; + const MessageDescriptor & containing_type = *field_descriptor.containing_type(); + if (containing_type.options().map_entry()) + { + /// Special case. Elements of the data type Map are named as "keys" and "values", + /// but they're internally named as "key" and "value" in protobuf schema. + if (field_descriptor.number() == 1) + { + if (ColumnNameWithProtobufFieldNameComparator::startsWith(column_name, "keys")) + matching_length = strlen("keys"); + else if (ColumnNameWithProtobufFieldNameComparator::startsWith(column_name, "key")) + matching_length = strlen("key"); + } + else if (field_descriptor.number() == 2) + { + if (ColumnNameWithProtobufFieldNameComparator::startsWith(column_name, "values")) + matching_length = strlen("values"); + else if (ColumnNameWithProtobufFieldNameComparator::startsWith(column_name, "value")) + matching_length = strlen("value"); + } + } + if (!matching_length && ColumnNameWithProtobufFieldNameComparator::startsWith(column_name, field_descriptor.name())) + { + matching_length = field_descriptor.name().length(); + } + if (column_name.length() == matching_length) + return true; + if ((column_name.length() < matching_length + 2) || !field_descriptor.message_type()) + return false; + char first_char_after_matching = column_name[matching_length]; + if (!ColumnNameWithProtobufFieldNameComparator::equals(first_char_after_matching, '.')) + return false; + suffix = column_name.substr(matching_length + 1); + return true; + } + + /// Finds fields in the protobuf message which can be considered as matching + /// for a specified column's name. The found fields can be nested messages, + /// for that case suffixes are also returned. + /// This is only the first filter, buildMessageSerializerImpl() does other checks after calling this function. + static bool findFieldsByColumnName( + const std::string_view & column_name, + const MessageDescriptor & message_descriptor, + std::vector> & out_field_descriptors_with_suffixes) + { + out_field_descriptors_with_suffixes.clear(); + + /// Find all fields which have the same name as column's name (case-insensitively); i.e. we're checking + /// field_name == column_name. + for (int i : ext::range(message_descriptor.field_count())) + { + const auto & field_descriptor = *message_descriptor.field(i); + if (columnNameEqualsToFieldName(column_name, field_descriptor)) + { + out_field_descriptors_with_suffixes.emplace_back(&field_descriptor, std::string_view{}); + break; + } + } + + if (!out_field_descriptors_with_suffixes.empty()) + return true; /// We have an exact match, no need to compare prefixes. + + /// Find all fields which name is used as prefix in column's name; i.e. we're checking + /// column_name == field_name + '.' + nested_message_field_name + for (int i : ext::range(message_descriptor.field_count())) + { + const auto & field_descriptor = *message_descriptor.field(i); + std::string_view suffix; + if (columnNameStartsWithFieldName(column_name, field_descriptor, suffix)) + { + out_field_descriptors_with_suffixes.emplace_back(&field_descriptor, suffix); + } + } + + /// Shorter suffixes first. + std::sort(out_field_descriptors_with_suffixes.begin(), out_field_descriptors_with_suffixes.end(), + [](const std::pair & f1, + const std::pair & f2) + { + return f1.second.length() < f2.second.length(); + }); + + return !out_field_descriptors_with_suffixes.empty(); + } + + /// Builds a serializer for a protobuf message (root or nested). + template + std::unique_ptr buildMessageSerializerImpl( + size_t num_columns, + const StringOrStringViewT * column_names, + const DataTypePtr * data_types, + std::vector & used_column_indices, + const MessageDescriptor & message_descriptor, + bool with_length_delimiter, + const FieldDescriptor * parent_field_descriptor) + { + std::vector field_descs; + boost::container::flat_map field_descriptors_in_use; + + used_column_indices.clear(); + used_column_indices.reserve(num_columns); + + auto add_field_serializer = [&](size_t column_index_, + const std::string_view & column_name_, + size_t num_columns_, + const FieldDescriptor & field_descriptor_, + std::unique_ptr field_serializer_) + { + auto it = field_descriptors_in_use.find(&field_descriptor_); + if (it != field_descriptors_in_use.end()) + { + throw Exception( + "Multiple columns (" + backQuote(StringRef{field_descriptors_in_use[&field_descriptor_]}) + ", " + + backQuote(StringRef{column_name_}) + ") cannot be serialized to a single protobuf field " + + quoteString(field_descriptor_.full_name()), + ErrorCodes::MULTIPLE_COLUMNS_SERIALIZED_TO_SAME_PROTOBUF_FIELD); + } + + field_descs.push_back({column_index_, num_columns_, &field_descriptor_, std::move(field_serializer_)}); + field_descriptors_in_use.emplace(&field_descriptor_, column_name_); + }; + + std::vector> field_descriptors_with_suffixes; + + /// We're going through all the passed columns. + size_t column_idx = 0; + size_t next_column_idx = 1; + for (; column_idx != num_columns; column_idx = next_column_idx++) + { + auto column_name = column_names[column_idx]; + const auto & data_type = data_types[column_idx]; + + if (!findFieldsByColumnName(column_name, message_descriptor, field_descriptors_with_suffixes)) + continue; + + if ((field_descriptors_with_suffixes.size() == 1) && field_descriptors_with_suffixes[0].second.empty()) + { + /// Simple case: one column is serialized as one field. + const auto & field_descriptor = *field_descriptors_with_suffixes[0].first; + auto field_serializer = buildFieldSerializer(column_name, data_type, field_descriptor, field_descriptor.is_repeated()); + + if (field_serializer) + { + add_field_serializer(column_idx, column_name, 1, field_descriptor, std::move(field_serializer)); + used_column_indices.push_back(column_idx); + continue; + } + } + + for (const auto & [field_descriptor, suffix] : field_descriptors_with_suffixes) + { + if (!suffix.empty()) + { + /// Complex case: one or more columns are serialized as a nested message. + std::vector names_relative_to_nested_message; + names_relative_to_nested_message.reserve(num_columns - column_idx); + names_relative_to_nested_message.emplace_back(suffix); + + for (size_t j : ext::range(column_idx + 1, num_columns)) + { + std::string_view next_suffix; + if (!columnNameStartsWithFieldName(column_names[j], *field_descriptor, next_suffix)) + break; + names_relative_to_nested_message.emplace_back(next_suffix); + } + + /// Now we have up to `names_relative_to_nested_message.size()` sequential columns + /// which can be serialized as a nested message. + + /// Calculate how many of those sequential columns are arrays. + size_t num_arrays = 0; + for (size_t j : ext::range(column_idx, column_idx + names_relative_to_nested_message.size())) + { + if (data_types[j]->getTypeId() != TypeIndex::Array) + break; + ++num_arrays; + } + + /// We will try to serialize the sequential columns as one nested message, + /// then, if failed, as an array of nested messages (on condition those columns are array). + bool has_fallback_to_array_of_nested_messages = num_arrays && field_descriptor->is_repeated(); + + /// Try to serialize the sequential columns as one nested message. + try + { + std::vector used_column_indices_in_nested; + auto nested_message_serializer = buildMessageSerializerImpl( + names_relative_to_nested_message.size(), + names_relative_to_nested_message.data(), + &data_types[column_idx], + used_column_indices_in_nested, + *field_descriptor->message_type(), + false, + field_descriptor); + + if (nested_message_serializer) + { + for (size_t & idx_in_nested : used_column_indices_in_nested) + used_column_indices.push_back(idx_in_nested + column_idx); + + next_column_idx = used_column_indices.back() + 1; + add_field_serializer(column_idx, column_name, next_column_idx - column_idx, *field_descriptor, std::move(nested_message_serializer)); + break; + } + } + catch (Exception & e) + { + if ((e.code() != ErrorCodes::PROTOBUF_FIELD_NOT_REPEATED) || !has_fallback_to_array_of_nested_messages) + throw; + } + + if (has_fallback_to_array_of_nested_messages) + { + /// Try to serialize the sequential columns as an array of nested messages. + DataTypes array_nested_data_types; + array_nested_data_types.reserve(num_arrays); + for (size_t j : ext::range(column_idx, column_idx + num_arrays)) + array_nested_data_types.emplace_back(assert_cast(*data_types[j]).getNestedType()); + + std::vector used_column_indices_in_nested; + auto nested_message_serializer = buildMessageSerializerImpl( + array_nested_data_types.size(), + names_relative_to_nested_message.data(), + array_nested_data_types.data(), + used_column_indices_in_nested, + *field_descriptor->message_type(), + false, + field_descriptor); + + if (nested_message_serializer) + { + auto field_serializer = std::make_unique(std::move(nested_message_serializer)); + + for (size_t & idx_in_nested : used_column_indices_in_nested) + used_column_indices.push_back(idx_in_nested + column_idx); + + next_column_idx = used_column_indices.back() + 1; + add_field_serializer(column_idx, column_name, next_column_idx - column_idx, *field_descriptor, std::move(field_serializer)); + break; + } + } + } + } + } + + /// Check that we've found matching columns for all the required fields. + if ((message_descriptor.file()->syntax() == google::protobuf::FileDescriptor::SYNTAX_PROTO2) + && reader_or_writer.writer) + { + for (int i : ext::range(message_descriptor.field_count())) + { + const auto & field_descriptor = *message_descriptor.field(i); + if (field_descriptor.is_required() && !field_descriptors_in_use.count(&field_descriptor)) + throw Exception( + "Field " + quoteString(field_descriptor.full_name()) + " is required to be set", + ErrorCodes::NO_COLUMN_SERIALIZED_TO_REQUIRED_PROTOBUF_FIELD); + } + } + + if (field_descs.empty()) + return nullptr; + + return std::make_unique( + std::move(field_descs), parent_field_descriptor, with_length_delimiter, reader_or_writer); + } + + /// Builds a serializer for one-to-one match: + /// one column is serialized as one field in the protobuf message. + std::unique_ptr buildFieldSerializer( + const std::string_view & column_name, + const DataTypePtr & data_type, + const FieldDescriptor & field_descriptor, + bool allow_repeat) + { + auto data_type_id = data_type->getTypeId(); + switch (data_type_id) + { + case TypeIndex::UInt8: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::UInt16: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::UInt32: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::UInt64: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::UInt128: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::UInt256: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::Int8: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::Int16: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::Int32: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::Int64: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::Int128: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::Int256: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::Float32: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::Float64: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::Date: return std::make_unique(field_descriptor, reader_or_writer); + case TypeIndex::DateTime: return std::make_unique(field_descriptor, reader_or_writer); + case TypeIndex::DateTime64: return std::make_unique(assert_cast(*data_type), field_descriptor, reader_or_writer); + case TypeIndex::String: return std::make_unique>(field_descriptor, reader_or_writer); + case TypeIndex::FixedString: return std::make_unique>(assert_cast(*data_type), field_descriptor, reader_or_writer); + case TypeIndex::Enum8: return std::make_unique>(typeid_cast>(data_type), field_descriptor, reader_or_writer); + case TypeIndex::Enum16: return std::make_unique>(typeid_cast>(data_type), field_descriptor, reader_or_writer); + case TypeIndex::Decimal32: return std::make_unique>(assert_cast &>(*data_type), field_descriptor, reader_or_writer); + case TypeIndex::Decimal64: return std::make_unique>(assert_cast &>(*data_type), field_descriptor, reader_or_writer); + case TypeIndex::Decimal128: return std::make_unique>(assert_cast &>(*data_type), field_descriptor, reader_or_writer); + case TypeIndex::Decimal256: return std::make_unique>(assert_cast &>(*data_type), field_descriptor, reader_or_writer); + case TypeIndex::UUID: return std::make_unique(field_descriptor, reader_or_writer); + case TypeIndex::Interval: return std::make_unique(field_descriptor, reader_or_writer); + case TypeIndex::AggregateFunction: return std::make_unique(typeid_cast>(data_type), field_descriptor, reader_or_writer); + + case TypeIndex::Nullable: + { + const auto & nullable_data_type = assert_cast(*data_type); + auto nested_serializer = buildFieldSerializer(column_name, nullable_data_type.getNestedType(), field_descriptor, allow_repeat); + if (!nested_serializer) + return nullptr; + return std::make_unique(std::move(nested_serializer)); + } + + case TypeIndex::LowCardinality: + { + const auto & low_cardinality_data_type = assert_cast(*data_type); + auto nested_serializer + = buildFieldSerializer(column_name, low_cardinality_data_type.getDictionaryType(), field_descriptor, allow_repeat); + if (!nested_serializer) + return nullptr; + return std::make_unique(std::move(nested_serializer)); + } + + case TypeIndex::Map: + { + const auto & map_data_type = assert_cast(*data_type); + auto nested_serializer = buildFieldSerializer(column_name, map_data_type.getNestedType(), field_descriptor, allow_repeat); + if (!nested_serializer) + return nullptr; + return std::make_unique(std::move(nested_serializer)); + } + + case TypeIndex::Array: + { + /// Array is serialized as a repeated field. + const auto & array_data_type = assert_cast(*data_type); + + if (!allow_repeat) + { + throw Exception( + "The field " + quoteString(field_descriptor.full_name()) + + " must be repeated in the protobuf schema to match the column " + backQuote(StringRef{column_name}), + ErrorCodes::PROTOBUF_FIELD_NOT_REPEATED); + } + + auto nested_serializer = buildFieldSerializer(column_name, array_data_type.getNestedType(), field_descriptor, + /* allow_repeat = */ false); // We do our repeating now, so for nested type we forget about the repeating. + if (!nested_serializer) + return nullptr; + return std::make_unique(std::move(nested_serializer)); + } + + case TypeIndex::Tuple: + { + /// Tuple is serialized in one of two ways: + /// 1) If the tuple has explicit names then it can be serialized as a nested message. + /// 2) Any tuple can be serialized as a repeated field, just like Array. + const auto & tuple_data_type = assert_cast(*data_type); + size_t size_of_tuple = tuple_data_type.getElements().size(); + + if (tuple_data_type.haveExplicitNames() && field_descriptor.message_type()) + { + /// Try to serialize as a nested message. + std::vector used_column_indices; + auto nested_message_serializer = buildMessageSerializerImpl( + size_of_tuple, + tuple_data_type.getElementNames().data(), + tuple_data_type.getElements().data(), + used_column_indices, + *field_descriptor.message_type(), + false, + &field_descriptor); + + if (!nested_message_serializer) + { + throw Exception( + "Not found matches between the names of the tuple's elements {" + + boost::algorithm::join(tuple_data_type.getElementNames(), ", ") + "} and the fields {" + + boost::algorithm::join(getFieldNames(*field_descriptor.message_type()), ", ") + "} of the message " + + quoteString(field_descriptor.message_type()->full_name()) + " in the protobuf schema", + ErrorCodes::NO_COLUMNS_SERIALIZED_TO_PROTOBUF_FIELDS); + } + + return std::make_unique(std::move(nested_message_serializer)); + } + + /// Serialize as a repeated field. + if (!allow_repeat && (size_of_tuple > 1)) + { + throw Exception( + "The field " + quoteString(field_descriptor.full_name()) + + " must be repeated in the protobuf schema to match the column " + backQuote(StringRef{column_name}), + ErrorCodes::PROTOBUF_FIELD_NOT_REPEATED); + } + + std::vector> nested_serializers; + for (const auto & nested_data_type : tuple_data_type.getElements()) + { + auto nested_serializer = buildFieldSerializer(column_name, nested_data_type, field_descriptor, + /* allow_repeat = */ false); // We do our repeating now, so for nested type we forget about the repeating. + if (!nested_serializer) + break; + nested_serializers.push_back(std::move(nested_serializer)); + } + + if (nested_serializers.size() != size_of_tuple) + return nullptr; + + return std::make_unique( + typeid_cast>(data_type), + field_descriptor, + std::move(nested_serializers)); + } + + default: + throw Exception("Unknown data type: " + data_type->getName(), ErrorCodes::LOGICAL_ERROR); + } + } + + const ProtobufReaderOrWriter reader_or_writer; + }; +} + + +std::unique_ptr ProtobufSerializer::create( + const Strings & column_names, + const DataTypes & data_types, + std::vector & missing_column_indices, + const google::protobuf::Descriptor & message_descriptor, + bool with_length_delimiter, + ProtobufReader & reader) +{ + return ProtobufSerializerBuilder(reader).buildMessageSerializer(column_names, data_types, missing_column_indices, message_descriptor, with_length_delimiter); +} + +std::unique_ptr ProtobufSerializer::create( + const Strings & column_names, + const DataTypes & data_types, + const google::protobuf::Descriptor & message_descriptor, + bool with_length_delimiter, + ProtobufWriter & writer) +{ + std::vector missing_column_indices; + return ProtobufSerializerBuilder(writer).buildMessageSerializer(column_names, data_types, missing_column_indices, message_descriptor, with_length_delimiter); +} +} +#endif diff --git a/src/Formats/ProtobufSerializer.h b/src/Formats/ProtobufSerializer.h new file mode 100644 index 00000000000..86a2f2f36dd --- /dev/null +++ b/src/Formats/ProtobufSerializer.h @@ -0,0 +1,52 @@ +#pragma once + +#if !defined(ARCADIA_BUILD) +# include "config_formats.h" +#endif + +#if USE_PROTOBUF +# include + + +namespace google::protobuf { class Descriptor; } + +namespace DB +{ +class ProtobufReader; +class ProtobufWriter; +class IDataType; +using DataTypePtr = std::shared_ptr; +using DataTypes = std::vector; + + +/// Utility class, does all the work for serialization in the Protobuf format. +class ProtobufSerializer +{ +public: + virtual ~ProtobufSerializer() = default; + + virtual void setColumns(const ColumnPtr * columns, size_t num_columns) = 0; + virtual void writeRow(size_t row_num) = 0; + + virtual void setColumns(const MutableColumnPtr * columns, size_t num_columns) = 0; + virtual void readRow(size_t row_num) = 0; + virtual void insertDefaults(size_t row_num) = 0; + + static std::unique_ptr create( + const Strings & column_names, + const DataTypes & data_types, + std::vector & missing_column_indices, + const google::protobuf::Descriptor & message_descriptor, + bool with_length_delimiter, + ProtobufReader & reader); + + static std::unique_ptr create( + const Strings & column_names, + const DataTypes & data_types, + const google::protobuf::Descriptor & message_descriptor, + bool with_length_delimiter, + ProtobufWriter & writer); +}; + +} +#endif diff --git a/src/Formats/ProtobufWriter.cpp b/src/Formats/ProtobufWriter.cpp index e62d8fc4a58..ece4f78b1c8 100644 --- a/src/Formats/ProtobufWriter.cpp +++ b/src/Formats/ProtobufWriter.cpp @@ -1,29 +1,11 @@ #include "ProtobufWriter.h" #if USE_PROTOBUF -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include +# include namespace DB { -namespace ErrorCodes -{ - extern const int NOT_IMPLEMENTED; - extern const int NO_DATA_FOR_REQUIRED_PROTOBUF_FIELD; - extern const int PROTOBUF_BAD_CAST; - extern const int PROTOBUF_FIELD_NOT_REPEATED; -} - - namespace { constexpr size_t MAX_VARINT_SIZE = 10; @@ -81,66 +63,24 @@ namespace } void writeFieldNumber(UInt32 field_number, WireType wire_type, PODArray & buf) { writeVarint((field_number << 3) | wire_type, buf); } - - // Should we pack repeated values while storing them. - // It depends on type of the field in the protobuf schema and the syntax of that schema. - bool shouldPackRepeated(const google::protobuf::FieldDescriptor * field) - { - if (!field->is_repeated()) - return false; - switch (field->type()) - { - case google::protobuf::FieldDescriptor::TYPE_INT32: - case google::protobuf::FieldDescriptor::TYPE_UINT32: - case google::protobuf::FieldDescriptor::TYPE_SINT32: - case google::protobuf::FieldDescriptor::TYPE_INT64: - case google::protobuf::FieldDescriptor::TYPE_UINT64: - case google::protobuf::FieldDescriptor::TYPE_SINT64: - case google::protobuf::FieldDescriptor::TYPE_FIXED32: - case google::protobuf::FieldDescriptor::TYPE_SFIXED32: - case google::protobuf::FieldDescriptor::TYPE_FIXED64: - case google::protobuf::FieldDescriptor::TYPE_SFIXED64: - case google::protobuf::FieldDescriptor::TYPE_FLOAT: - case google::protobuf::FieldDescriptor::TYPE_DOUBLE: - case google::protobuf::FieldDescriptor::TYPE_BOOL: - case google::protobuf::FieldDescriptor::TYPE_ENUM: - break; - default: - return false; - } - if (field->options().has_packed()) - return field->options().packed(); - return field->file()->syntax() == google::protobuf::FileDescriptor::SYNTAX_PROTO3; - } - - // Should we omit null values (zero for numbers / empty string for strings) while storing them. - bool shouldSkipNullValue(const google::protobuf::FieldDescriptor * field) - { - return field->is_optional() && (field->file()->syntax() == google::protobuf::FileDescriptor::SYNTAX_PROTO3); - } } -// SimpleWriter is an utility class to serialize protobufs. -// Knows nothing about protobuf schemas, just provides useful functions to serialize data. -ProtobufWriter::SimpleWriter::SimpleWriter(WriteBuffer & out_, const bool use_length_delimiters_) +ProtobufWriter::ProtobufWriter(WriteBuffer & out_) : out(out_) - , current_piece_start(0) - , num_bytes_skipped(0) - , use_length_delimiters(use_length_delimiters_) { } -ProtobufWriter::SimpleWriter::~SimpleWriter() = default; +ProtobufWriter::~ProtobufWriter() = default; -void ProtobufWriter::SimpleWriter::startMessage() +void ProtobufWriter::startMessage() { } -void ProtobufWriter::SimpleWriter::endMessage() +void ProtobufWriter::endMessage(bool with_length_delimiter) { pieces.emplace_back(current_piece_start, buffer.size()); - if (use_length_delimiters) + if (with_length_delimiter) { size_t size_of_message = buffer.size() - num_bytes_skipped; writeVarint(size_of_message, out); @@ -154,7 +94,7 @@ void ProtobufWriter::SimpleWriter::endMessage() current_piece_start = 0; } -void ProtobufWriter::SimpleWriter::startNestedMessage() +void ProtobufWriter::startNestedMessage() { nested_infos.emplace_back(pieces.size(), num_bytes_skipped); pieces.emplace_back(current_piece_start, buffer.size()); @@ -167,7 +107,7 @@ void ProtobufWriter::SimpleWriter::startNestedMessage() num_bytes_skipped = NESTED_MESSAGE_PADDING; } -void ProtobufWriter::SimpleWriter::endNestedMessage(UInt32 field_number, bool is_group, bool skip_if_empty) +void ProtobufWriter::endNestedMessage(int field_number, bool is_group, bool skip_if_empty) { const auto & nested_info = nested_infos.back(); size_t num_pieces_at_start = nested_info.num_pieces_at_start; @@ -203,8 +143,13 @@ void ProtobufWriter::SimpleWriter::endNestedMessage(UInt32 field_number, bool is num_bytes_skipped += num_bytes_skipped_at_start - num_bytes_inserted; } -void ProtobufWriter::SimpleWriter::writeUInt(UInt32 field_number, UInt64 value) +void ProtobufWriter::writeUInt(int field_number, UInt64 value) { + if (in_repeated_pack) + { + writeVarint(value, buffer); + return; + } size_t old_size = buffer.size(); buffer.reserve(old_size + 2 * MAX_VARINT_SIZE); UInt8 * ptr = buffer.data() + old_size; @@ -213,20 +158,27 @@ void ProtobufWriter::SimpleWriter::writeUInt(UInt32 field_number, UInt64 value) buffer.resize_assume_reserved(ptr - buffer.data()); } -void ProtobufWriter::SimpleWriter::writeInt(UInt32 field_number, Int64 value) +void ProtobufWriter::writeInt(int field_number, Int64 value) { writeUInt(field_number, static_cast(value)); } -void ProtobufWriter::SimpleWriter::writeSInt(UInt32 field_number, Int64 value) +void ProtobufWriter::writeSInt(int field_number, Int64 value) { writeUInt(field_number, encodeZigZag(value)); } template -void ProtobufWriter::SimpleWriter::writeFixed(UInt32 field_number, T value) +void ProtobufWriter::writeFixed(int field_number, T value) { static_assert((sizeof(T) == 4) || (sizeof(T) == 8)); + if (in_repeated_pack) + { + size_t old_size = buffer.size(); + buffer.resize(old_size + sizeof(T)); + memcpy(buffer.data() + old_size, &value, sizeof(T)); + return; + } constexpr WireType wire_type = (sizeof(T) == 4) ? BITS32 : BITS64; size_t old_size = buffer.size(); buffer.reserve(old_size + MAX_VARINT_SIZE + sizeof(T)); @@ -237,19 +189,27 @@ void ProtobufWriter::SimpleWriter::writeFixed(UInt32 field_number, T value) buffer.resize_assume_reserved(ptr - buffer.data()); } -void ProtobufWriter::SimpleWriter::writeString(UInt32 field_number, const StringRef & str) +template void ProtobufWriter::writeFixed(int field_number, Int32 value); +template void ProtobufWriter::writeFixed(int field_number, UInt32 value); +template void ProtobufWriter::writeFixed(int field_number, Int64 value); +template void ProtobufWriter::writeFixed(int field_number, UInt64 value); +template void ProtobufWriter::writeFixed(int field_number, Float32 value); +template void ProtobufWriter::writeFixed(int field_number, Float64 value); + +void ProtobufWriter::writeString(int field_number, const std::string_view & str) { + size_t length = str.length(); size_t old_size = buffer.size(); - buffer.reserve(old_size + 2 * MAX_VARINT_SIZE + str.size); + buffer.reserve(old_size + 2 * MAX_VARINT_SIZE + length); UInt8 * ptr = buffer.data() + old_size; ptr = writeFieldNumber(field_number, LENGTH_DELIMITED, ptr); - ptr = writeVarint(str.size, ptr); - memcpy(ptr, str.data, str.size); - ptr += str.size; + ptr = writeVarint(length, ptr); + memcpy(ptr, str.data(), length); + ptr += length; buffer.resize_assume_reserved(ptr - buffer.data()); } -void ProtobufWriter::SimpleWriter::startRepeatedPack() +void ProtobufWriter::startRepeatedPack() { pieces.emplace_back(current_piece_start, buffer.size()); @@ -259,17 +219,19 @@ void ProtobufWriter::SimpleWriter::startRepeatedPack() current_piece_start = buffer.size() + REPEATED_PACK_PADDING; buffer.resize(current_piece_start); num_bytes_skipped += REPEATED_PACK_PADDING; + in_repeated_pack = true; } -void ProtobufWriter::SimpleWriter::endRepeatedPack(UInt32 field_number) +void ProtobufWriter::endRepeatedPack(int field_number, bool skip_if_empty) { size_t size = buffer.size() - current_piece_start; - if (!size) + if (!size && skip_if_empty) { current_piece_start = pieces.back().start; buffer.resize(pieces.back().end); pieces.pop_back(); num_bytes_skipped -= REPEATED_PACK_PADDING; + in_repeated_pack = false; return; } UInt8 * ptr = &buffer[pieces.back().end]; @@ -278,726 +240,7 @@ void ProtobufWriter::SimpleWriter::endRepeatedPack(UInt32 field_number) size_t num_bytes_inserted = endptr - ptr; pieces.back().end += num_bytes_inserted; num_bytes_skipped -= num_bytes_inserted; -} - -void ProtobufWriter::SimpleWriter::addUIntToRepeatedPack(UInt64 value) -{ - writeVarint(value, buffer); -} - -void ProtobufWriter::SimpleWriter::addIntToRepeatedPack(Int64 value) -{ - writeVarint(static_cast(value), buffer); -} - -void ProtobufWriter::SimpleWriter::addSIntToRepeatedPack(Int64 value) -{ - writeVarint(encodeZigZag(value), buffer); -} - -template -void ProtobufWriter::SimpleWriter::addFixedToRepeatedPack(T value) -{ - static_assert((sizeof(T) == 4) || (sizeof(T) == 8)); - size_t old_size = buffer.size(); - buffer.resize(old_size + sizeof(T)); - memcpy(buffer.data() + old_size, &value, sizeof(T)); -} - - -// Implementation for a converter from any DB data type to any protobuf field type. -class ProtobufWriter::ConverterBaseImpl : public IConverter -{ -public: - ConverterBaseImpl(SimpleWriter & simple_writer_, const google::protobuf::FieldDescriptor * field_) - : simple_writer(simple_writer_), field(field_) - { - field_number = field->number(); - } - - virtual void writeString(const StringRef &) override { cannotConvertType("String"); } - virtual void writeInt8(Int8) override { cannotConvertType("Int8"); } - virtual void writeUInt8(UInt8) override { cannotConvertType("UInt8"); } - virtual void writeInt16(Int16) override { cannotConvertType("Int16"); } - virtual void writeUInt16(UInt16) override { cannotConvertType("UInt16"); } - virtual void writeInt32(Int32) override { cannotConvertType("Int32"); } - virtual void writeUInt32(UInt32) override { cannotConvertType("UInt32"); } - virtual void writeInt64(Int64) override { cannotConvertType("Int64"); } - virtual void writeUInt64(UInt64) override { cannotConvertType("UInt64"); } - virtual void writeInt128(Int128) override { cannotConvertType("Int128"); } - virtual void writeUInt128(const UInt128 &) override { cannotConvertType("UInt128"); } - virtual void writeInt256(const Int256 &) override { cannotConvertType("Int256"); } - virtual void writeUInt256(const UInt256 &) override { cannotConvertType("UInt256"); } - virtual void writeFloat32(Float32) override { cannotConvertType("Float32"); } - virtual void writeFloat64(Float64) override { cannotConvertType("Float64"); } - virtual void prepareEnumMapping8(const std::vector> &) override {} - virtual void prepareEnumMapping16(const std::vector> &) override {} - virtual void writeEnum8(Int8) override { cannotConvertType("Enum"); } - virtual void writeEnum16(Int16) override { cannotConvertType("Enum"); } - virtual void writeUUID(const UUID &) override { cannotConvertType("UUID"); } - virtual void writeDate(DayNum) override { cannotConvertType("Date"); } - virtual void writeDateTime(time_t) override { cannotConvertType("DateTime"); } - virtual void writeDateTime64(DateTime64, UInt32) override { cannotConvertType("DateTime64"); } - virtual void writeDecimal32(Decimal32, UInt32) override { cannotConvertType("Decimal32"); } - virtual void writeDecimal64(Decimal64, UInt32) override { cannotConvertType("Decimal64"); } - virtual void writeDecimal128(const Decimal128 &, UInt32) override { cannotConvertType("Decimal128"); } - virtual void writeDecimal256(const Decimal256 &, UInt32) override { cannotConvertType("Decimal256"); } - - virtual void writeAggregateFunction(const AggregateFunctionPtr &, ConstAggregateDataPtr) override { cannotConvertType("AggregateFunction"); } - -protected: - [[noreturn]] void cannotConvertType(const String & type_name) - { - throw Exception( - "Could not convert data type '" + type_name + "' to protobuf type '" + field->type_name() + "' (field: " + field->name() + ")", - ErrorCodes::PROTOBUF_BAD_CAST); - } - - [[noreturn]] void cannotConvertValue(const String & value) - { - throw Exception( - "Could not convert value '" + value + "' to protobuf type '" + field->type_name() + "' (field: " + field->name() + ")", - ErrorCodes::PROTOBUF_BAD_CAST); - } - - template - To numericCast(From value) - { - if constexpr (std::is_same_v) - return value; - To result; - try - { - result = boost::numeric_cast(value); - } - catch (boost::numeric::bad_numeric_cast &) - { - cannotConvertValue(toString(value)); - } - return result; - } - - template - To parseFromString(const StringRef & str) - { - To result; - try - { - result = ::DB::parse(str.data, str.size); - } - catch (...) - { - cannotConvertValue(str.toString()); - } - return result; - } - - SimpleWriter & simple_writer; - const google::protobuf::FieldDescriptor * field; - UInt32 field_number; -}; - - -template -class ProtobufWriter::ConverterToString : public ConverterBaseImpl -{ -public: - using ConverterBaseImpl::ConverterBaseImpl; - - void writeString(const StringRef & str) override { writeField(str); } - - void writeInt8(Int8 value) override { convertToStringAndWriteField(value); } - void writeUInt8(UInt8 value) override { convertToStringAndWriteField(value); } - void writeInt16(Int16 value) override { convertToStringAndWriteField(value); } - void writeUInt16(UInt16 value) override { convertToStringAndWriteField(value); } - void writeInt32(Int32 value) override { convertToStringAndWriteField(value); } - void writeUInt32(UInt32 value) override { convertToStringAndWriteField(value); } - void writeInt64(Int64 value) override { convertToStringAndWriteField(value); } - void writeUInt64(UInt64 value) override { convertToStringAndWriteField(value); } - void writeFloat32(Float32 value) override { convertToStringAndWriteField(value); } - void writeFloat64(Float64 value) override { convertToStringAndWriteField(value); } - - void prepareEnumMapping8(const std::vector> & name_value_pairs) override - { - prepareEnumValueToNameMap(name_value_pairs); - } - void prepareEnumMapping16(const std::vector> & name_value_pairs) override - { - prepareEnumValueToNameMap(name_value_pairs); - } - - void writeEnum8(Int8 value) override { writeEnum16(value); } - - void writeEnum16(Int16 value) override - { - auto it = enum_value_to_name_map->find(value); - if (it == enum_value_to_name_map->end()) - cannotConvertValue(toString(value)); - writeField(it->second); - } - - void writeUUID(const UUID & uuid) override { convertToStringAndWriteField(uuid); } - void writeDate(DayNum date) override { convertToStringAndWriteField(date); } - - void writeDateTime(time_t tm) override - { - writeDateTimeText(tm, text_buffer); - writeField(text_buffer.stringRef()); - text_buffer.restart(); - } - - void writeDateTime64(DateTime64 date_time, UInt32 scale) override - { - writeDateTimeText(date_time, scale, text_buffer); - writeField(text_buffer.stringRef()); - text_buffer.restart(); - } - - void writeDecimal32(Decimal32 decimal, UInt32 scale) override { writeDecimal(decimal, scale); } - void writeDecimal64(Decimal64 decimal, UInt32 scale) override { writeDecimal(decimal, scale); } - void writeDecimal128(const Decimal128 & decimal, UInt32 scale) override { writeDecimal(decimal, scale); } - - void writeAggregateFunction(const AggregateFunctionPtr & function, ConstAggregateDataPtr place) override - { - function->serialize(place, text_buffer); - writeField(text_buffer.stringRef()); - text_buffer.restart(); - } - -private: - template - void convertToStringAndWriteField(T value) - { - writeText(value, text_buffer); - writeField(text_buffer.stringRef()); - text_buffer.restart(); - } - - template - void writeDecimal(const Decimal & decimal, UInt32 scale) - { - writeText(decimal, scale, text_buffer); - writeField(text_buffer.stringRef()); - text_buffer.restart(); - } - - template - void prepareEnumValueToNameMap(const std::vector> & name_value_pairs) - { - if (enum_value_to_name_map.has_value()) - return; - enum_value_to_name_map.emplace(); - for (const auto & name_value_pair : name_value_pairs) - enum_value_to_name_map->emplace(name_value_pair.second, name_value_pair.first); - } - - void writeField(const StringRef & str) - { - if constexpr (skip_null_value) - { - if (!str.size) - return; - } - simple_writer.writeString(field_number, str); - } - - WriteBufferFromOwnString text_buffer; - std::optional> enum_value_to_name_map; -}; - -# define PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(field_type_id) \ - template <> \ - std::unique_ptr ProtobufWriter::createConverter( \ - const google::protobuf::FieldDescriptor * field) \ - { \ - if (shouldSkipNullValue(field)) \ - return std::make_unique>(simple_writer, field); \ - else \ - return std::make_unique>(simple_writer, field); \ - } -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(google::protobuf::FieldDescriptor::TYPE_STRING) -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS(google::protobuf::FieldDescriptor::TYPE_BYTES) -# undef PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_STRINGS - - -template -class ProtobufWriter::ConverterToNumber : public ConverterBaseImpl -{ -public: - using ConverterBaseImpl::ConverterBaseImpl; - - void writeString(const StringRef & str) override { writeField(parseFromString(str)); } - - void writeInt8(Int8 value) override { castNumericAndWriteField(value); } - void writeUInt8(UInt8 value) override { castNumericAndWriteField(value); } - void writeInt16(Int16 value) override { castNumericAndWriteField(value); } - void writeUInt16(UInt16 value) override { castNumericAndWriteField(value); } - void writeInt32(Int32 value) override { castNumericAndWriteField(value); } - void writeUInt32(UInt32 value) override { castNumericAndWriteField(value); } - void writeInt64(Int64 value) override { castNumericAndWriteField(value); } - void writeUInt64(UInt64 value) override { castNumericAndWriteField(value); } - void writeFloat32(Float32 value) override { castNumericAndWriteField(value); } - void writeFloat64(Float64 value) override { castNumericAndWriteField(value); } - - void writeEnum8(Int8 value) override { writeEnum16(value); } - - void writeEnum16(Int16 value) override - { - if constexpr (!is_integer_v) - cannotConvertType("Enum"); // It's not correct to convert enum to floating point. - castNumericAndWriteField(value); - } - - void writeDate(DayNum date) override { castNumericAndWriteField(static_cast(date)); } - void writeDateTime(time_t tm) override { castNumericAndWriteField(tm); } - void writeDateTime64(DateTime64 date_time, UInt32 scale) override { writeDecimal(date_time, scale); } - void writeDecimal32(Decimal32 decimal, UInt32 scale) override { writeDecimal(decimal, scale); } - void writeDecimal64(Decimal64 decimal, UInt32 scale) override { writeDecimal(decimal, scale); } - void writeDecimal128(const Decimal128 & decimal, UInt32 scale) override { writeDecimal(decimal, scale); } - -private: - template - void castNumericAndWriteField(FromType value) - { - writeField(numericCast(value)); - } - - template - void writeDecimal(const Decimal & decimal, UInt32 scale) - { - castNumericAndWriteField(DecimalUtils::convertTo(decimal, scale)); - } - - void writeField(ToType value) - { - if constexpr (skip_null_value) - { - if (value == 0) - return; - } - if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_INT32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_INT64) && std::is_same_v)) - { - if constexpr (pack_repeated) - simple_writer.addIntToRepeatedPack(value); - else - simple_writer.writeInt(field_number, value); - } - else if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_SINT32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SINT64) && std::is_same_v)) - { - if constexpr (pack_repeated) - simple_writer.addSIntToRepeatedPack(value); - else - simple_writer.writeSInt(field_number, value); - } - else if constexpr (((field_type_id == google::protobuf::FieldDescriptor::TYPE_UINT32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_UINT64) && std::is_same_v)) - { - if constexpr (pack_repeated) - simple_writer.addUIntToRepeatedPack(value); - else - simple_writer.writeUInt(field_number, value); - } - else - { - static_assert(((field_type_id == google::protobuf::FieldDescriptor::TYPE_FIXED32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SFIXED32) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_FIXED64) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_SFIXED64) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_FLOAT) && std::is_same_v) - || ((field_type_id == google::protobuf::FieldDescriptor::TYPE_DOUBLE) && std::is_same_v)); - if constexpr (pack_repeated) - simple_writer.addFixedToRepeatedPack(value); - else - simple_writer.writeFixed(field_number, value); - } - } -}; - -# define PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(field_type_id, field_type) \ - template <> \ - std::unique_ptr ProtobufWriter::createConverter( \ - const google::protobuf::FieldDescriptor * field) \ - { \ - if (shouldSkipNullValue(field)) \ - return std::make_unique>(simple_writer, field); \ - else if (shouldPackRepeated(field)) \ - return std::make_unique>(simple_writer, field); \ - else \ - return std::make_unique>(simple_writer, field); \ - } - -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_INT32, Int32); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SINT32, Int32); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_UINT32, UInt32); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_INT64, Int64); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SINT64, Int64); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_UINT64, UInt64); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FIXED32, UInt32); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SFIXED32, Int32); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FIXED64, UInt64); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_SFIXED64, Int64); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_FLOAT, float); -PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS(google::protobuf::FieldDescriptor::TYPE_DOUBLE, double); -# undef PROTOBUF_WRITER_CREATE_CONVERTER_SPECIALIZATION_FOR_NUMBERS - - -template -class ProtobufWriter::ConverterToBool : public ConverterBaseImpl -{ -public: - using ConverterBaseImpl::ConverterBaseImpl; - - void writeString(const StringRef & str) override - { - if (str == "true") - writeField(true); - else if (str == "false") - writeField(false); - else - cannotConvertValue(str.toString()); - } - - void writeInt8(Int8 value) override { convertToBoolAndWriteField(value); } - void writeUInt8(UInt8 value) override { convertToBoolAndWriteField(value); } - void writeInt16(Int16 value) override { convertToBoolAndWriteField(value); } - void writeUInt16(UInt16 value) override { convertToBoolAndWriteField(value); } - void writeInt32(Int32 value) override { convertToBoolAndWriteField(value); } - void writeUInt32(UInt32 value) override { convertToBoolAndWriteField(value); } - void writeInt64(Int64 value) override { convertToBoolAndWriteField(value); } - void writeUInt64(UInt64 value) override { convertToBoolAndWriteField(value); } - void writeFloat32(Float32 value) override { convertToBoolAndWriteField(value); } - void writeFloat64(Float64 value) override { convertToBoolAndWriteField(value); } - void writeDecimal32(Decimal32 decimal, UInt32) override { convertToBoolAndWriteField(decimal.value); } - void writeDecimal64(Decimal64 decimal, UInt32) override { convertToBoolAndWriteField(decimal.value); } - void writeDecimal128(const Decimal128 & decimal, UInt32) override { convertToBoolAndWriteField(decimal.value); } - -private: - template - void convertToBoolAndWriteField(T value) - { - writeField(static_cast(value)); - } - - void writeField(bool b) - { - if constexpr (skip_null_value) - { - if (!b) - return; - } - if constexpr (pack_repeated) - simple_writer.addUIntToRepeatedPack(b); - else - simple_writer.writeUInt(field_number, b); - } -}; - -template <> -std::unique_ptr ProtobufWriter::createConverter( - const google::protobuf::FieldDescriptor * field) -{ - if (shouldSkipNullValue(field)) - return std::make_unique>(simple_writer, field); - else if (shouldPackRepeated(field)) - return std::make_unique>(simple_writer, field); - else - return std::make_unique>(simple_writer, field); -} - - -template -class ProtobufWriter::ConverterToEnum : public ConverterBaseImpl -{ -public: - using ConverterBaseImpl::ConverterBaseImpl; - - void writeString(const StringRef & str) override - { - prepareEnumNameToPbNumberMap(); - auto it = enum_name_to_pbnumber_map->find(str); - if (it == enum_name_to_pbnumber_map->end()) - cannotConvertValue(str.toString()); - writeField(it->second); - } - - void writeInt8(Int8 value) override { convertToEnumAndWriteField(value); } - void writeUInt8(UInt8 value) override { convertToEnumAndWriteField(value); } - void writeInt16(Int16 value) override { convertToEnumAndWriteField(value); } - void writeUInt16(UInt16 value) override { convertToEnumAndWriteField(value); } - void writeInt32(Int32 value) override { convertToEnumAndWriteField(value); } - void writeUInt32(UInt32 value) override { convertToEnumAndWriteField(value); } - void writeInt64(Int64 value) override { convertToEnumAndWriteField(value); } - void writeUInt64(UInt64 value) override { convertToEnumAndWriteField(value); } - - void prepareEnumMapping8(const std::vector> & name_value_pairs) override - { - prepareEnumValueToPbNumberMap(name_value_pairs); - } - void prepareEnumMapping16(const std::vector> & name_value_pairs) override - { - prepareEnumValueToPbNumberMap(name_value_pairs); - } - - void writeEnum8(Int8 value) override { writeEnum16(value); } - - void writeEnum16(Int16 value) override - { - int pbnumber; - if (enum_value_always_equals_pbnumber) - pbnumber = value; - else - { - auto it = enum_value_to_pbnumber_map->find(value); - if (it == enum_value_to_pbnumber_map->end()) - cannotConvertValue(toString(value)); - pbnumber = it->second; - } - writeField(pbnumber); - } - -private: - template - void convertToEnumAndWriteField(T value) - { - const auto * enum_descriptor = field->enum_type()->FindValueByNumber(numericCast(value)); - if (!enum_descriptor) - cannotConvertValue(toString(value)); - writeField(enum_descriptor->number()); - } - - void prepareEnumNameToPbNumberMap() - { - if (enum_name_to_pbnumber_map.has_value()) - return; - enum_name_to_pbnumber_map.emplace(); - const auto * enum_type = field->enum_type(); - for (int i = 0; i != enum_type->value_count(); ++i) - { - const auto * enum_value = enum_type->value(i); - enum_name_to_pbnumber_map->emplace(enum_value->name(), enum_value->number()); - } - } - - template - void prepareEnumValueToPbNumberMap(const std::vector> & name_value_pairs) - { - if (enum_value_to_pbnumber_map.has_value()) - return; - enum_value_to_pbnumber_map.emplace(); - enum_value_always_equals_pbnumber = true; - for (const auto & name_value_pair : name_value_pairs) - { - Int16 value = name_value_pair.second; // NOLINT - const auto * enum_descriptor = field->enum_type()->FindValueByName(name_value_pair.first); - if (enum_descriptor) - { - enum_value_to_pbnumber_map->emplace(value, enum_descriptor->number()); - if (value != enum_descriptor->number()) - enum_value_always_equals_pbnumber = false; - } - else - enum_value_always_equals_pbnumber = false; - } - } - - void writeField(int enum_pbnumber) - { - if constexpr (skip_null_value) - { - if (!enum_pbnumber) - return; - } - if constexpr (pack_repeated) - simple_writer.addUIntToRepeatedPack(enum_pbnumber); - else - simple_writer.writeUInt(field_number, enum_pbnumber); - } - - std::optional> enum_name_to_pbnumber_map; - std::optional> enum_value_to_pbnumber_map; - bool enum_value_always_equals_pbnumber; -}; - -template <> -std::unique_ptr ProtobufWriter::createConverter( - const google::protobuf::FieldDescriptor * field) -{ - if (shouldSkipNullValue(field)) - return std::make_unique>(simple_writer, field); - else if (shouldPackRepeated(field)) - return std::make_unique>(simple_writer, field); - else - return std::make_unique>(simple_writer, field); -} - - -ProtobufWriter::ProtobufWriter( - WriteBuffer & out, const google::protobuf::Descriptor * message_type, const std::vector & column_names, const bool use_length_delimiters_) - : simple_writer(out, use_length_delimiters_) -{ - std::vector field_descriptors_without_match; - root_message = ProtobufColumnMatcher::matchColumns(column_names, message_type, field_descriptors_without_match); - for (const auto * field_descriptor_without_match : field_descriptors_without_match) - { - if (field_descriptor_without_match->is_required()) - throw Exception( - "Output doesn't have a column named '" + field_descriptor_without_match->name() - + "' which is required to write the output in the protobuf format.", - ErrorCodes::NO_DATA_FOR_REQUIRED_PROTOBUF_FIELD); - } - setTraitsDataAfterMatchingColumns(root_message.get()); -} - -ProtobufWriter::~ProtobufWriter() = default; - -void ProtobufWriter::setTraitsDataAfterMatchingColumns(Message * message) -{ - Field * parent_field = message->parent ? &message->parent->fields[message->index_in_parent] : nullptr; - message->data.parent_field_number = parent_field ? parent_field->field_number : 0; - message->data.is_required = parent_field && parent_field->data.is_required; - - if (parent_field && parent_field->data.is_repeatable) - message->data.repeatable_container_message = message; - else if (message->parent) - message->data.repeatable_container_message = message->parent->data.repeatable_container_message; - else - message->data.repeatable_container_message = nullptr; - - message->data.is_group = parent_field && (parent_field->field_descriptor->type() == google::protobuf::FieldDescriptor::TYPE_GROUP); - - for (auto & field : message->fields) - { - field.data.is_repeatable = field.field_descriptor->is_repeated(); - field.data.is_required = field.field_descriptor->is_required(); - field.data.repeatable_container_message = message->data.repeatable_container_message; - field.data.should_pack_repeated = shouldPackRepeated(field.field_descriptor); - - if (field.nested_message) - { - setTraitsDataAfterMatchingColumns(field.nested_message.get()); - continue; - } - switch (field.field_descriptor->type()) - { -# define PROTOBUF_WRITER_CONVERTER_CREATING_CASE(field_type_id) \ - case field_type_id: \ - field.data.converter = createConverter(field.field_descriptor); \ - break - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_STRING); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_BYTES); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_INT32); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SINT32); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_UINT32); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FIXED32); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SFIXED32); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_INT64); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SINT64); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_UINT64); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FIXED64); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_SFIXED64); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_FLOAT); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_DOUBLE); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_BOOL); - PROTOBUF_WRITER_CONVERTER_CREATING_CASE(google::protobuf::FieldDescriptor::TYPE_ENUM); -# undef PROTOBUF_WRITER_CONVERTER_CREATING_CASE - default: - throw Exception( - String("Protobuf type '") + field.field_descriptor->type_name() + "' isn't supported", ErrorCodes::NOT_IMPLEMENTED); - } - } -} - -void ProtobufWriter::startMessage() -{ - current_message = root_message.get(); - current_field_index = 0; - simple_writer.startMessage(); -} - -void ProtobufWriter::endMessage() -{ - if (!current_message) - return; - endWritingField(); - while (current_message->parent) - { - simple_writer.endNestedMessage( - current_message->data.parent_field_number, current_message->data.is_group, !current_message->data.is_required); - current_message = current_message->parent; - } - simple_writer.endMessage(); - current_message = nullptr; -} - -bool ProtobufWriter::writeField(size_t & column_index) -{ - endWritingField(); - while (true) - { - if (current_field_index < current_message->fields.size()) - { - Field & field = current_message->fields[current_field_index]; - if (!field.nested_message) - { - current_field = ¤t_message->fields[current_field_index]; - current_converter = current_field->data.converter.get(); - column_index = current_field->column_index; - if (current_field->data.should_pack_repeated) - simple_writer.startRepeatedPack(); - return true; - } - simple_writer.startNestedMessage(); - current_message = field.nested_message.get(); - current_message->data.need_repeat = false; - current_field_index = 0; - continue; - } - if (current_message->parent) - { - simple_writer.endNestedMessage( - current_message->data.parent_field_number, current_message->data.is_group, !current_message->data.is_required); - if (current_message->data.need_repeat) - { - simple_writer.startNestedMessage(); - current_message->data.need_repeat = false; - current_field_index = 0; - continue; - } - current_field_index = current_message->index_in_parent + 1; - current_message = current_message->parent; - continue; - } - return false; - } -} - -void ProtobufWriter::endWritingField() -{ - if (!current_field) - return; - if (current_field->data.should_pack_repeated) - simple_writer.endRepeatedPack(current_field->field_number); - else if ((num_values == 0) && current_field->data.is_required) - throw Exception( - "No data for the required field '" + current_field->field_descriptor->name() + "'", - ErrorCodes::NO_DATA_FOR_REQUIRED_PROTOBUF_FIELD); - - current_field = nullptr; - current_converter = nullptr; - num_values = 0; - ++current_field_index; -} - -void ProtobufWriter::setNestedMessageNeedsRepeat() -{ - if (current_field->data.repeatable_container_message) - current_field->data.repeatable_container_message->data.need_repeat = true; - else - throw Exception( - "Cannot write more than single value to the non-repeated field '" + current_field->field_descriptor->name() + "'", - ErrorCodes::PROTOBUF_FIELD_NOT_REPEATED); + in_repeated_pack = false; } } diff --git a/src/Formats/ProtobufWriter.h b/src/Formats/ProtobufWriter.h index 52bb453aa73..6af1a237fbd 100644 --- a/src/Formats/ProtobufWriter.h +++ b/src/Formats/ProtobufWriter.h @@ -1,290 +1,68 @@ #pragma once -#include -#include -#include - #if !defined(ARCADIA_BUILD) # include "config_formats.h" #endif #if USE_PROTOBUF -# include -# include -# include -# include "ProtobufColumnMatcher.h" - - -namespace google -{ -namespace protobuf -{ - class Descriptor; - class FieldDescriptor; -} -} - -namespace DB -{ -class IAggregateFunction; -using AggregateFunctionPtr = std::shared_ptr; -using ConstAggregateDataPtr = const char *; - - -/** Serializes a protobuf, tries to cast types if necessarily. - */ -class ProtobufWriter : private boost::noncopyable -{ -public: - ProtobufWriter(WriteBuffer & out, const google::protobuf::Descriptor * message_type, const std::vector & column_names, const bool use_length_delimiters_); - ~ProtobufWriter(); - - /// Should be called at the beginning of writing a message. - void startMessage(); - - /// Should be called at the end of writing a message. - void endMessage(); - - /// Prepares for writing values of a field. - /// Returns true and sets 'column_index' to the corresponding column's index. - /// Returns false if there are no more fields to write in the message type (call endMessage() in this case). - bool writeField(size_t & column_index); - - /// Writes a value. This function should be called one or multiple times after writeField(). - /// Returns false if there are no more place for the values in the protobuf's field. - /// This can happen if the protobuf's field is not declared as repeated in the protobuf schema. - bool writeNumber(Int8 value) { return writeValueIfPossible(&IConverter::writeInt8, value); } - bool writeNumber(UInt8 value) { return writeValueIfPossible(&IConverter::writeUInt8, value); } - bool writeNumber(Int16 value) { return writeValueIfPossible(&IConverter::writeInt16, value); } - bool writeNumber(UInt16 value) { return writeValueIfPossible(&IConverter::writeUInt16, value); } - bool writeNumber(Int32 value) { return writeValueIfPossible(&IConverter::writeInt32, value); } - bool writeNumber(UInt32 value) { return writeValueIfPossible(&IConverter::writeUInt32, value); } - bool writeNumber(Int64 value) { return writeValueIfPossible(&IConverter::writeInt64, value); } - bool writeNumber(UInt64 value) { return writeValueIfPossible(&IConverter::writeUInt64, value); } - bool writeNumber(Int128 value) { return writeValueIfPossible(&IConverter::writeInt128, value); } - bool writeNumber(UInt128 value) { return writeValueIfPossible(&IConverter::writeUInt128, value); } - - bool writeNumber(Int256 value) { return writeValueIfPossible(&IConverter::writeInt256, value); } - bool writeNumber(UInt256 value) { return writeValueIfPossible(&IConverter::writeUInt256, value); } - - bool writeNumber(Float32 value) { return writeValueIfPossible(&IConverter::writeFloat32, value); } - bool writeNumber(Float64 value) { return writeValueIfPossible(&IConverter::writeFloat64, value); } - bool writeString(const StringRef & str) { return writeValueIfPossible(&IConverter::writeString, str); } - void prepareEnumMapping(const std::vector> & enum_values) { current_converter->prepareEnumMapping8(enum_values); } - void prepareEnumMapping(const std::vector> & enum_values) { current_converter->prepareEnumMapping16(enum_values); } - bool writeEnum(Int8 value) { return writeValueIfPossible(&IConverter::writeEnum8, value); } - bool writeEnum(Int16 value) { return writeValueIfPossible(&IConverter::writeEnum16, value); } - bool writeUUID(const UUID & uuid) { return writeValueIfPossible(&IConverter::writeUUID, uuid); } - bool writeDate(DayNum date) { return writeValueIfPossible(&IConverter::writeDate, date); } - bool writeDateTime(time_t tm) { return writeValueIfPossible(&IConverter::writeDateTime, tm); } - bool writeDateTime64(DateTime64 tm, UInt32 scale) { return writeValueIfPossible(&IConverter::writeDateTime64, tm, scale); } - bool writeDecimal(Decimal32 decimal, UInt32 scale) { return writeValueIfPossible(&IConverter::writeDecimal32, decimal, scale); } - bool writeDecimal(Decimal64 decimal, UInt32 scale) { return writeValueIfPossible(&IConverter::writeDecimal64, decimal, scale); } - bool writeDecimal(const Decimal128 & decimal, UInt32 scale) { return writeValueIfPossible(&IConverter::writeDecimal128, decimal, scale); } - bool writeDecimal(const Decimal256 & decimal, UInt32 scale) { return writeValueIfPossible(&IConverter::writeDecimal256, decimal, scale); } - bool writeAggregateFunction(const AggregateFunctionPtr & function, ConstAggregateDataPtr place) { return writeValueIfPossible(&IConverter::writeAggregateFunction, function, place); } - -private: - class SimpleWriter - { - public: - SimpleWriter(WriteBuffer & out_, const bool use_length_delimiters_); - ~SimpleWriter(); - - void startMessage(); - void endMessage(); - - void startNestedMessage(); - void endNestedMessage(UInt32 field_number, bool is_group, bool skip_if_empty); - - void writeInt(UInt32 field_number, Int64 value); - void writeUInt(UInt32 field_number, UInt64 value); - void writeSInt(UInt32 field_number, Int64 value); - template - void writeFixed(UInt32 field_number, T value); - void writeString(UInt32 field_number, const StringRef & str); - - void startRepeatedPack(); - void addIntToRepeatedPack(Int64 value); - void addUIntToRepeatedPack(UInt64 value); - void addSIntToRepeatedPack(Int64 value); - template - void addFixedToRepeatedPack(T value); - void endRepeatedPack(UInt32 field_number); - - private: - struct Piece - { - size_t start; - size_t end; - Piece(size_t start_, size_t end_) : start(start_), end(end_) {} - Piece() = default; - }; - - struct NestedInfo - { - size_t num_pieces_at_start; - size_t num_bytes_skipped_at_start; - NestedInfo(size_t num_pieces_at_start_, size_t num_bytes_skipped_at_start_) - : num_pieces_at_start(num_pieces_at_start_), num_bytes_skipped_at_start(num_bytes_skipped_at_start_) - { - } - }; - - WriteBuffer & out; - PODArray buffer; - std::vector pieces; - size_t current_piece_start; - size_t num_bytes_skipped; - std::vector nested_infos; - const bool use_length_delimiters; - }; - - class IConverter - { - public: - virtual ~IConverter() = default; - virtual void writeString(const StringRef &) = 0; - virtual void writeInt8(Int8) = 0; - virtual void writeUInt8(UInt8) = 0; - virtual void writeInt16(Int16) = 0; - virtual void writeUInt16(UInt16) = 0; - virtual void writeInt32(Int32) = 0; - virtual void writeUInt32(UInt32) = 0; - virtual void writeInt64(Int64) = 0; - virtual void writeUInt64(UInt64) = 0; - virtual void writeInt128(Int128) = 0; - virtual void writeUInt128(const UInt128 &) = 0; - - virtual void writeInt256(const Int256 &) = 0; - virtual void writeUInt256(const UInt256 &) = 0; - - virtual void writeFloat32(Float32) = 0; - virtual void writeFloat64(Float64) = 0; - virtual void prepareEnumMapping8(const std::vector> &) = 0; - virtual void prepareEnumMapping16(const std::vector> &) = 0; - virtual void writeEnum8(Int8) = 0; - virtual void writeEnum16(Int16) = 0; - virtual void writeUUID(const UUID &) = 0; - virtual void writeDate(DayNum) = 0; - virtual void writeDateTime(time_t) = 0; - virtual void writeDateTime64(DateTime64, UInt32 scale) = 0; - virtual void writeDecimal32(Decimal32, UInt32) = 0; - virtual void writeDecimal64(Decimal64, UInt32) = 0; - virtual void writeDecimal128(const Decimal128 &, UInt32) = 0; - virtual void writeDecimal256(const Decimal256 &, UInt32) = 0; - virtual void writeAggregateFunction(const AggregateFunctionPtr &, ConstAggregateDataPtr) = 0; - }; - - class ConverterBaseImpl; - template - class ConverterToString; - template - class ConverterToNumber; - template - class ConverterToBool; - template - class ConverterToEnum; - - struct ColumnMatcherTraits - { - struct FieldData - { - std::unique_ptr converter; - bool is_required; - bool is_repeatable; - bool should_pack_repeated; - ProtobufColumnMatcher::Message * repeatable_container_message; - }; - struct MessageData - { - UInt32 parent_field_number; - bool is_group; - bool is_required; - ProtobufColumnMatcher::Message * repeatable_container_message; - bool need_repeat; - }; - }; - using Message = ProtobufColumnMatcher::Message; - using Field = ProtobufColumnMatcher::Field; - - void setTraitsDataAfterMatchingColumns(Message * message); - - template - std::unique_ptr createConverter(const google::protobuf::FieldDescriptor * field); - - template - using WriteValueFunctionPtr = void (IConverter::*)(Params...); - - template - bool writeValueIfPossible(WriteValueFunctionPtr func, Args &&... args) - { - if (num_values && !current_field->data.is_repeatable) - { - setNestedMessageNeedsRepeat(); - return false; - } - (current_converter->*func)(std::forward(args)...); - ++num_values; - return true; - } - - void setNestedMessageNeedsRepeat(); - void endWritingField(); - - SimpleWriter simple_writer; - std::unique_ptr root_message; - - Message * current_message; - size_t current_field_index = 0; - const Field * current_field = nullptr; - IConverter * current_converter = nullptr; - size_t num_values = 0; -}; - -} - -#else -# include +# include +# include namespace DB { -class IAggregateFunction; -using AggregateFunctionPtr = std::shared_ptr; -using ConstAggregateDataPtr = const char *; +class WriteBuffer; +/// Utility class for writing in the Protobuf format. +/// Knows nothing about protobuf schemas, just provides useful functions to serialize data. class ProtobufWriter { public: - bool writeNumber(Int8 /* value */) { return false; } - bool writeNumber(UInt8 /* value */) { return false; } - bool writeNumber(Int16 /* value */) { return false; } - bool writeNumber(UInt16 /* value */) { return false; } - bool writeNumber(Int32 /* value */) { return false; } - bool writeNumber(UInt32 /* value */) { return false; } - bool writeNumber(Int64 /* value */) { return false; } - bool writeNumber(UInt64 /* value */) { return false; } - bool writeNumber(Int128 /* value */) { return false; } - bool writeNumber(UInt128 /* value */) { return false; } - bool writeNumber(Int256 /* value */) { return false; } - bool writeNumber(UInt256 /* value */) { return false; } - bool writeNumber(Float32 /* value */) { return false; } - bool writeNumber(Float64 /* value */) { return false; } - bool writeString(const StringRef & /* value */) { return false; } - void prepareEnumMapping(const std::vector> & /* name_value_pairs */) {} - void prepareEnumMapping(const std::vector> & /* name_value_pairs */) {} - bool writeEnum(Int8 /* value */) { return false; } - bool writeEnum(Int16 /* value */) { return false; } - bool writeUUID(const UUID & /* value */) { return false; } - bool writeDate(DayNum /* date */) { return false; } - bool writeDateTime(time_t /* tm */) { return false; } - bool writeDateTime64(DateTime64 /*tm*/, UInt32 /*scale*/) { return false; } - bool writeDecimal(Decimal32 /* decimal */, UInt32 /* scale */) { return false; } - bool writeDecimal(Decimal64 /* decimal */, UInt32 /* scale */) { return false; } - bool writeDecimal(const Decimal128 & /* decimal */, UInt32 /* scale */) { return false; } - bool writeDecimal(const Decimal256 & /* decimal */, UInt32 /* scale */) { return false; } - bool writeAggregateFunction(const AggregateFunctionPtr & /* function */, ConstAggregateDataPtr /* place */) { return false; } + ProtobufWriter(WriteBuffer & out_); + ~ProtobufWriter(); + + void startMessage(); + void endMessage(bool with_length_delimiter); + + void startNestedMessage(); + void endNestedMessage(int field_number, bool is_group, bool skip_if_empty); + + void writeInt(int field_number, Int64 value); + void writeUInt(int field_number, UInt64 value); + void writeSInt(int field_number, Int64 value); + template + void writeFixed(int field_number, T value); + void writeString(int field_number, const std::string_view & str); + + void startRepeatedPack(); + void endRepeatedPack(int field_number, bool skip_if_empty); + +private: + struct Piece + { + size_t start; + size_t end; + Piece(size_t start_, size_t end_) : start(start_), end(end_) {} + Piece() = default; + }; + + struct NestedInfo + { + size_t num_pieces_at_start; + size_t num_bytes_skipped_at_start; + NestedInfo(size_t num_pieces_at_start_, size_t num_bytes_skipped_at_start_) + : num_pieces_at_start(num_pieces_at_start_), num_bytes_skipped_at_start(num_bytes_skipped_at_start_) + { + } + }; + + WriteBuffer & out; + PODArray buffer; + std::vector pieces; + size_t current_piece_start = 0; + size_t num_bytes_skipped = 0; + std::vector nested_infos; + bool in_repeated_pack = false; }; } diff --git a/src/Formats/ya.make b/src/Formats/ya.make index 6b72ec397d5..8fe938be125 100644 --- a/src/Formats/ya.make +++ b/src/Formats/ya.make @@ -20,9 +20,9 @@ SRCS( NativeFormat.cpp NullFormat.cpp ParsedTemplateFormatString.cpp - ProtobufColumnMatcher.cpp ProtobufReader.cpp ProtobufSchemas.cpp + ProtobufSerializer.cpp ProtobufWriter.cpp registerFormats.cpp verbosePrintString.cpp diff --git a/src/Processors/Formats/Impl/ProtobufRowInputFormat.cpp b/src/Processors/Formats/Impl/ProtobufRowInputFormat.cpp index d1420d0d38e..22a758b80f6 100644 --- a/src/Processors/Formats/Impl/ProtobufRowInputFormat.cpp +++ b/src/Processors/Formats/Impl/ProtobufRowInputFormat.cpp @@ -1,57 +1,48 @@ #include "ProtobufRowInputFormat.h" #if USE_PROTOBUF -#include -#include -#include -#include -#include +# include +# include +# include +# include +# include +# include +# include +# include namespace DB { - -ProtobufRowInputFormat::ProtobufRowInputFormat(ReadBuffer & in_, const Block & header_, Params params_, const FormatSchemaInfo & info_, const bool use_length_delimiters_) +ProtobufRowInputFormat::ProtobufRowInputFormat(ReadBuffer & in_, const Block & header_, const Params & params_, const FormatSchemaInfo & schema_info_, bool with_length_delimiter_) : IRowInputFormat(header_, in_, params_) - , data_types(header_.getDataTypes()) - , reader(in, ProtobufSchemas::instance().getMessageTypeForFormatSchema(info_), header_.getNames(), use_length_delimiters_) + , reader(std::make_unique(in_)) + , serializer(ProtobufSerializer::create( + header_.getNames(), + header_.getDataTypes(), + missing_column_indices, + *ProtobufSchemas::instance().getMessageTypeForFormatSchema(schema_info_), + with_length_delimiter_, + *reader)) { } ProtobufRowInputFormat::~ProtobufRowInputFormat() = default; -bool ProtobufRowInputFormat::readRow(MutableColumns & columns, RowReadExtension & extra) +bool ProtobufRowInputFormat::readRow(MutableColumns & columns, RowReadExtension & row_read_extension) { - if (!reader.startMessage()) - return false; // EOF reached, no more messages. + if (reader->eof()) + return false; - // Set of columns for which the values were read. The rest will be filled with default values. - auto & read_columns = extra.read_columns; - read_columns.assign(columns.size(), false); + size_t row_num = columns.empty() ? 0 : columns[0]->size(); + if (!row_num) + serializer->setColumns(columns.data(), columns.size()); - // Read values from this message and put them to the columns while it's possible. - size_t column_index; - while (reader.readColumnIndex(column_index)) - { - bool allow_add_row = !static_cast(read_columns[column_index]); - do - { - bool row_added; - data_types[column_index]->deserializeProtobuf(*columns[column_index], reader, allow_add_row, row_added); - if (row_added) - { - read_columns[column_index] = true; - allow_add_row = false; - } - } while (reader.canReadMoreValues()); - } + serializer->readRow(row_num); - // Fill non-visited columns with the default values. - for (column_index = 0; column_index < read_columns.size(); ++column_index) - if (!read_columns[column_index]) - data_types[column_index]->insertDefaultInto(*columns[column_index]); - - reader.endMessage(); + row_read_extension.read_columns.clear(); + row_read_extension.read_columns.resize(columns.size(), true); + for (size_t column_idx : missing_column_indices) + row_read_extension.read_columns[column_idx] = false; return true; } @@ -62,14 +53,14 @@ bool ProtobufRowInputFormat::allowSyncAfterError() const void ProtobufRowInputFormat::syncAfterError() { - reader.endMessage(true); + reader->endMessage(true); } void registerInputFormatProcessorProtobuf(FormatFactory & factory) { - for (bool use_length_delimiters : {false, true}) + for (bool with_length_delimiter : {false, true}) { - factory.registerInputFormatProcessor(use_length_delimiters ? "Protobuf" : "ProtobufSingle", [use_length_delimiters]( + factory.registerInputFormatProcessor(with_length_delimiter ? "Protobuf" : "ProtobufSingle", [with_length_delimiter]( ReadBuffer & buf, const Block & sample, IRowInputFormat::Params params, @@ -78,7 +69,7 @@ void registerInputFormatProcessorProtobuf(FormatFactory & factory) return std::make_shared(buf, sample, std::move(params), FormatSchemaInfo(settings.schema.format_schema, "Protobuf", true, settings.schema.is_server, settings.schema.format_schema_path), - use_length_delimiters); + with_length_delimiter); }); } } diff --git a/src/Processors/Formats/Impl/ProtobufRowInputFormat.h b/src/Processors/Formats/Impl/ProtobufRowInputFormat.h index c6bc350e893..b2eabd4f37c 100644 --- a/src/Processors/Formats/Impl/ProtobufRowInputFormat.h +++ b/src/Processors/Formats/Impl/ProtobufRowInputFormat.h @@ -5,14 +5,14 @@ #endif #if USE_PROTOBUF -# include -# include # include namespace DB { class Block; class FormatSchemaInfo; +class ProtobufReader; +class ProtobufSerializer; /** Stream designed to deserialize data from the google protobuf format. @@ -29,18 +29,19 @@ class FormatSchemaInfo; class ProtobufRowInputFormat : public IRowInputFormat { public: - ProtobufRowInputFormat(ReadBuffer & in_, const Block & header_, Params params_, const FormatSchemaInfo & info_, const bool use_length_delimiters_); + ProtobufRowInputFormat(ReadBuffer & in_, const Block & header_, const Params & params_, const FormatSchemaInfo & schema_info_, bool with_length_delimiter_); ~ProtobufRowInputFormat() override; String getName() const override { return "ProtobufRowInputFormat"; } - bool readRow(MutableColumns & columns, RowReadExtension & extra) override; + bool readRow(MutableColumns & columns, RowReadExtension &) override; bool allowSyncAfterError() const override; void syncAfterError() override; private: - DataTypes data_types; - ProtobufReader reader; + std::unique_ptr reader; + std::vector missing_column_indices; + std::unique_ptr serializer; }; } diff --git a/src/Processors/Formats/Impl/ProtobufRowOutputFormat.cpp b/src/Processors/Formats/Impl/ProtobufRowOutputFormat.cpp index 3c885e80e31..d3b9a0124c1 100644 --- a/src/Processors/Formats/Impl/ProtobufRowOutputFormat.cpp +++ b/src/Processors/Formats/Impl/ProtobufRowOutputFormat.cpp @@ -1,13 +1,13 @@ -#include #include "ProtobufRowOutputFormat.h" #if USE_PROTOBUF - -#include -#include -#include -#include -#include +# include +# include +# include +# include +# include +# include +# include namespace DB @@ -20,58 +20,55 @@ namespace ErrorCodes ProtobufRowOutputFormat::ProtobufRowOutputFormat( WriteBuffer & out_, - const Block & header, + const Block & header_, const RowOutputFormatParams & params_, - const FormatSchemaInfo & format_schema, - const FormatSettings & settings) - : IRowOutputFormat(header, out_, params_) - , data_types(header.getDataTypes()) - , writer(out, - ProtobufSchemas::instance().getMessageTypeForFormatSchema(format_schema), - header.getNames(), settings.protobuf.write_row_delimiters) - , allow_only_one_row( - !settings.protobuf.write_row_delimiters - && !settings.protobuf.allow_many_rows_no_delimiters) + const FormatSchemaInfo & schema_info_, + const FormatSettings & settings_, + bool with_length_delimiter_) + : IRowOutputFormat(header_, out_, params_) + , writer(std::make_unique(out)) + , serializer(ProtobufSerializer::create( + header_.getNames(), + header_.getDataTypes(), + *ProtobufSchemas::instance().getMessageTypeForFormatSchema(schema_info_), + with_length_delimiter_, + *writer)) + , allow_multiple_rows(with_length_delimiter_ || settings_.protobuf.allow_multiple_rows_without_delimiter) { - value_indices.resize(header.columns()); } void ProtobufRowOutputFormat::write(const Columns & columns, size_t row_num) { - if (allow_only_one_row && !first_row) - { - throw Exception("The ProtobufSingle format can't be used to write multiple rows because this format doesn't have any row delimiter.", ErrorCodes::NO_ROW_DELIMITER); - } + if (!allow_multiple_rows && !first_row) + throw Exception( + "The ProtobufSingle format can't be used to write multiple rows because this format doesn't have any row delimiter.", + ErrorCodes::NO_ROW_DELIMITER); - writer.startMessage(); - std::fill(value_indices.begin(), value_indices.end(), 0); - size_t column_index; - while (writer.writeField(column_index)) - data_types[column_index]->serializeProtobuf( - *columns[column_index], row_num, writer, value_indices[column_index]); - writer.endMessage(); + if (!row_num) + serializer->setColumns(columns.data(), columns.size()); + + serializer->writeRow(row_num); } void registerOutputFormatProcessorProtobuf(FormatFactory & factory) { - for (bool write_row_delimiters : {false, true}) + for (bool with_length_delimiter : {false, true}) { factory.registerOutputFormatProcessor( - write_row_delimiters ? "Protobuf" : "ProtobufSingle", - [write_row_delimiters](WriteBuffer & buf, + with_length_delimiter ? "Protobuf" : "ProtobufSingle", + [with_length_delimiter](WriteBuffer & buf, const Block & header, const RowOutputFormatParams & params, - const FormatSettings & _settings) + const FormatSettings & settings) { - FormatSettings settings = _settings; - settings.protobuf.write_row_delimiters = write_row_delimiters; return std::make_shared( buf, header, params, FormatSchemaInfo(settings.schema.format_schema, "Protobuf", true, settings.schema.is_server, settings.schema.format_schema_path), - settings); + settings, + with_length_delimiter); }); } } diff --git a/src/Processors/Formats/Impl/ProtobufRowOutputFormat.h b/src/Processors/Formats/Impl/ProtobufRowOutputFormat.h index 847f7607ff5..5f82950e891 100644 --- a/src/Processors/Formats/Impl/ProtobufRowOutputFormat.h +++ b/src/Processors/Formats/Impl/ProtobufRowOutputFormat.h @@ -8,21 +8,16 @@ # include # include # include -# include # include -namespace google -{ -namespace protobuf -{ - class Message; -} -} - - namespace DB { +class ProtobufWriter; +class ProtobufSerializer; +class FormatSchemaInfo; +struct FormatSettings; + /** Stream designed to serialize data in the google protobuf format. * Each row is written as a separated message. * @@ -38,10 +33,11 @@ class ProtobufRowOutputFormat : public IRowOutputFormat public: ProtobufRowOutputFormat( WriteBuffer & out_, - const Block & header, + const Block & header_, const RowOutputFormatParams & params_, - const FormatSchemaInfo & format_schema, - const FormatSettings & settings); + const FormatSchemaInfo & schema_info_, + const FormatSettings & settings_, + bool with_length_delimiter_); String getName() const override { return "ProtobufRowOutputFormat"; } @@ -50,10 +46,9 @@ public: std::string getContentType() const override { return "application/octet-stream"; } private: - DataTypes data_types; - ProtobufWriter writer; - std::vector value_indices; - const bool allow_only_one_row; + std::unique_ptr writer; + std::unique_ptr serializer; + const bool allow_multiple_rows; }; } diff --git a/src/Storages/Kafka/KafkaBlockOutputStream.cpp b/src/Storages/Kafka/KafkaBlockOutputStream.cpp index cfbb7ad2523..2cb0fd98c71 100644 --- a/src/Storages/Kafka/KafkaBlockOutputStream.cpp +++ b/src/Storages/Kafka/KafkaBlockOutputStream.cpp @@ -26,7 +26,7 @@ void KafkaBlockOutputStream::writePrefix() buffer = storage.createWriteBuffer(getHeader()); auto format_settings = getFormatSettings(*context); - format_settings.protobuf.allow_many_rows_no_delimiters = true; + format_settings.protobuf.allow_multiple_rows_without_delimiter = true; child = FormatFactory::instance().getOutputStream(storage.getFormatName(), *buffer, getHeader(), *context, diff --git a/src/Storages/RabbitMQ/RabbitMQBlockOutputStream.cpp b/src/Storages/RabbitMQ/RabbitMQBlockOutputStream.cpp index d239586bb65..a987fff3c64 100644 --- a/src/Storages/RabbitMQ/RabbitMQBlockOutputStream.cpp +++ b/src/Storages/RabbitMQ/RabbitMQBlockOutputStream.cpp @@ -34,7 +34,7 @@ void RabbitMQBlockOutputStream::writePrefix() buffer->activateWriting(); auto format_settings = getFormatSettings(context); - format_settings.protobuf.allow_many_rows_no_delimiters = true; + format_settings.protobuf.allow_multiple_rows_without_delimiter = true; child = FormatFactory::instance().getOutputStream(storage.getFormatName(), *buffer, getHeader(), context, diff --git a/tests/queries/0_stateless/00825_protobuf_format_array_3dim.proto b/tests/queries/0_stateless/00825_protobuf_format_array_3dim.proto new file mode 100644 index 00000000000..8673924c929 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_array_3dim.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +message ABC +{ + message nested + { + message nested + { + repeated int32 c = 1; + } + repeated nested b = 1; + } + repeated nested a = 1; +} \ No newline at end of file diff --git a/tests/queries/0_stateless/00825_protobuf_format_array_3dim.reference b/tests/queries/0_stateless/00825_protobuf_format_array_3dim.reference new file mode 100644 index 00000000000..69e7d5e1da8 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_array_3dim.reference @@ -0,0 +1,52 @@ +[[],[[]],[[1]],[[2,3],[4]]] +[[[5,6,7]],[[8,9,10]]] + +Binary representation: +00000000 1a 0a 00 0a 02 0a 00 0a 05 0a 03 0a 01 01 0a 0b |................| +00000010 0a 04 0a 02 02 03 0a 03 0a 01 04 12 0a 07 0a 05 |................| +00000020 0a 03 05 06 07 0a 07 0a 05 0a 03 08 09 0a |..............| +0000002e + +MESSAGE #1 AT 0x00000001 +a { +} +a { + b { + } +} +a { + b { + c: 1 + } +} +a { + b { + c: 2 + c: 3 + } + b { + c: 4 + } +} +MESSAGE #2 AT 0x0000001C +a { + b { + c: 5 + c: 6 + c: 7 + } +} +a { + b { + c: 8 + c: 9 + c: 10 + } +} + +Binary representation is as expected + +[[],[[]],[[1]],[[2,3],[4]]] +[[[5,6,7]],[[8,9,10]]] +[[],[[]],[[1]],[[2,3],[4]]] +[[[5,6,7]],[[8,9,10]]] diff --git a/tests/queries/0_stateless/00825_protobuf_format_array_3dim.sh b/tests/queries/0_stateless/00825_protobuf_format_array_3dim.sh new file mode 100755 index 00000000000..903217ca939 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_array_3dim.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +set -eo pipefail + +# Run the client. +$CLICKHOUSE_CLIENT --multiquery <<'EOF' +DROP TABLE IF EXISTS array_3dim_protobuf_00825; + +CREATE TABLE array_3dim_protobuf_00825 +( + `a_b_c` Array(Array(Array(Int32))) +) ENGINE = MergeTree ORDER BY tuple(); + +INSERT INTO array_3dim_protobuf_00825 VALUES ([[], [[]], [[1]], [[2,3],[4]]]), ([[[5, 6, 7]], [[8, 9, 10]]]); + +SELECT * FROM array_3dim_protobuf_00825; +EOF + +BINARY_FILE_PATH=$(mktemp "$CURDIR/00825_protobuf_format_array_3dim.XXXXXX.binary") +$CLICKHOUSE_CLIENT --query "SELECT * FROM array_3dim_protobuf_00825 FORMAT Protobuf SETTINGS format_schema = '$CURDIR/00825_protobuf_format_array_3dim:ABC'" > "$BINARY_FILE_PATH" + +# Check the output in the protobuf format +echo +$CURDIR/helpers/protobuf_length_delimited_encoder.py --decode_and_check --format_schema "$CURDIR/00825_protobuf_format_array_3dim:ABC" --input "$BINARY_FILE_PATH" + +# Check the input in the protobuf format (now the table contains the same data twice). +echo +$CLICKHOUSE_CLIENT --query "INSERT INTO array_3dim_protobuf_00825 FORMAT Protobuf SETTINGS format_schema='$CURDIR/00825_protobuf_format_array_3dim:ABC'" < "$BINARY_FILE_PATH" +$CLICKHOUSE_CLIENT --query "SELECT * FROM array_3dim_protobuf_00825" + +rm "$BINARY_FILE_PATH" diff --git a/tests/queries/0_stateless/00825_protobuf_format_array_of_arrays.proto b/tests/queries/0_stateless/00825_protobuf_format_array_of_arrays.proto new file mode 100644 index 00000000000..8f84164da2a --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_array_of_arrays.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +message AA { + message nested_array { + repeated double c = 2; + } + string a = 1; + repeated nested_array b = 2; +} \ No newline at end of file diff --git a/tests/queries/0_stateless/00825_protobuf_format_array_of_arrays.reference b/tests/queries/0_stateless/00825_protobuf_format_array_of_arrays.reference new file mode 100644 index 00000000000..5ea6780a3ba --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_array_of_arrays.reference @@ -0,0 +1,41 @@ +one [[1,2,3],[0.5,0.25],[],[4,5],[0.125,0.0625],[6]] + +Binary representation: +00000000 6b 0a 03 6f 6e 65 12 1a 12 18 00 00 00 00 00 00 |k..one..........| +00000010 f0 3f 00 00 00 00 00 00 00 40 00 00 00 00 00 00 |.?.......@......| +00000020 08 40 12 12 12 10 00 00 00 00 00 00 e0 3f 00 00 |.@...........?..| +00000030 00 00 00 00 d0 3f 12 00 12 12 12 10 00 00 00 00 |.....?..........| +00000040 00 00 10 40 00 00 00 00 00 00 14 40 12 12 12 10 |...@.......@....| +00000050 00 00 00 00 00 00 c0 3f 00 00 00 00 00 00 b0 3f |.......?.......?| +00000060 12 0a 12 08 00 00 00 00 00 00 18 40 |...........@| +0000006c + +MESSAGE #1 AT 0x00000001 +a: "one" +b { + c: 1 + c: 2 + c: 3 +} +b { + c: 0.5 + c: 0.25 +} +b { +} +b { + c: 4 + c: 5 +} +b { + c: 0.125 + c: 0.0625 +} +b { + c: 6 +} + +Binary representation is as expected + +one [[1,2,3],[0.5,0.25],[],[4,5],[0.125,0.0625],[6]] +one [[1,2,3],[0.5,0.25],[],[4,5],[0.125,0.0625],[6]] diff --git a/tests/queries/0_stateless/00825_protobuf_format_array_of_arrays.sh b/tests/queries/0_stateless/00825_protobuf_format_array_of_arrays.sh new file mode 100755 index 00000000000..0b386723091 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_array_of_arrays.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +# https://github.com/ClickHouse/ClickHouse/issues/9069 + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +set -eo pipefail + +# Run the client. +$CLICKHOUSE_CLIENT --multiquery <<'EOF' +CREATE TABLE array_of_arrays_protobuf_00825 +( + `a` String, + `b` Nested ( + `c` Array(Float64) + ) +) ENGINE = MergeTree ORDER BY tuple(); + +INSERT INTO array_of_arrays_protobuf_00825 VALUES ('one', [[1,2,3],[0.5,0.25],[],[4,5],[0.125,0.0625],[6]]); + +SELECT * FROM array_of_arrays_protobuf_00825; +EOF + +BINARY_FILE_PATH=$(mktemp "$CURDIR/00825_protobuf_format_array_of_arrays.XXXXXX.binary") +$CLICKHOUSE_CLIENT --query "SELECT * FROM array_of_arrays_protobuf_00825 FORMAT Protobuf SETTINGS format_schema = '$CURDIR/00825_protobuf_format_array_of_arrays:AA'" > "$BINARY_FILE_PATH" + +# Check the output in the protobuf format +echo +$CURDIR/helpers/protobuf_length_delimited_encoder.py --decode_and_check --format_schema "$CURDIR/00825_protobuf_format_array_of_arrays:AA" --input "$BINARY_FILE_PATH" + +# Check the input in the protobuf format (now the table contains the same data twice). +echo +$CLICKHOUSE_CLIENT --query "INSERT INTO array_of_arrays_protobuf_00825 FORMAT Protobuf SETTINGS format_schema='$CURDIR/00825_protobuf_format_array_of_arrays:AA'" < "$BINARY_FILE_PATH" +$CLICKHOUSE_CLIENT --query "SELECT * FROM array_of_arrays_protobuf_00825" + +rm "$BINARY_FILE_PATH" diff --git a/tests/queries/0_stateless/00825_protobuf_format_enum_mapping.proto b/tests/queries/0_stateless/00825_protobuf_format_enum_mapping.proto new file mode 100644 index 00000000000..ba558dbbadb --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_enum_mapping.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +message Message +{ + enum Enum + { + FIRST = 0; + SECOND = 1; + TEN = 10; + HUNDRED = 100; + }; + Enum x = 1; +}; \ No newline at end of file diff --git a/tests/queries/0_stateless/00825_protobuf_format_enum_mapping.reference b/tests/queries/0_stateless/00825_protobuf_format_enum_mapping.reference new file mode 100644 index 00000000000..ef8059bac28 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_enum_mapping.reference @@ -0,0 +1,31 @@ +Second +Third +First +First +Second + +Binary representation: +00000000 02 08 01 02 08 64 00 00 02 08 01 |.....d.....| +0000000b + +MESSAGE #1 AT 0x00000001 +x: SECOND +MESSAGE #2 AT 0x00000004 +x: HUNDRED +MESSAGE #3 AT 0x00000007 +MESSAGE #4 AT 0x00000008 +MESSAGE #5 AT 0x00000009 +x: SECOND + +Binary representation is as expected + +Second +Third +First +First +Second +Second +Third +First +First +Second diff --git a/tests/queries/0_stateless/00825_protobuf_format_enum_mapping.sh b/tests/queries/0_stateless/00825_protobuf_format_enum_mapping.sh new file mode 100755 index 00000000000..cbb387a62a5 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_enum_mapping.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash + +# https://github.com/ClickHouse/ClickHouse/issues/7438 + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +set -eo pipefail + +# Run the client. +$CLICKHOUSE_CLIENT --multiquery <<'EOF' +DROP TABLE IF EXISTS enum_mapping_protobuf_00825; + +CREATE TABLE enum_mapping_protobuf_00825 +( + x Enum16('First'=-100, 'Second'=0, 'Third'=100) +) ENGINE = MergeTree ORDER BY tuple(); + +INSERT INTO enum_mapping_protobuf_00825 VALUES ('Second'), ('Third'), ('First'), ('First'), ('Second'); + +SELECT * FROM enum_mapping_protobuf_00825; +EOF + +BINARY_FILE_PATH=$(mktemp "$CURDIR/00825_protobuf_format_enum_mapping.XXXXXX.binary") +$CLICKHOUSE_CLIENT --query "SELECT * FROM enum_mapping_protobuf_00825 FORMAT Protobuf SETTINGS format_schema = '$CURDIR/00825_protobuf_format_enum_mapping:Message'" > "$BINARY_FILE_PATH" + +# Check the output in the protobuf format +echo +$CURDIR/helpers/protobuf_length_delimited_encoder.py --decode_and_check --format_schema "$CURDIR/00825_protobuf_format_enum_mapping:Message" --input "$BINARY_FILE_PATH" + +# Check the input in the protobuf format (now the table contains the same data twice). +echo +$CLICKHOUSE_CLIENT --query "INSERT INTO enum_mapping_protobuf_00825 FORMAT Protobuf SETTINGS format_schema='$CURDIR/00825_protobuf_format_enum_mapping:Message'" < "$BINARY_FILE_PATH" +$CLICKHOUSE_CLIENT --query "SELECT * FROM enum_mapping_protobuf_00825" + +rm "$BINARY_FILE_PATH" diff --git a/tests/queries/0_stateless/00825_protobuf_format_map.proto b/tests/queries/0_stateless/00825_protobuf_format_map.proto new file mode 100644 index 00000000000..561b409b733 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_map.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +message Message { + map a = 1; +}; diff --git a/tests/queries/0_stateless/00825_protobuf_format_map.reference b/tests/queries/0_stateless/00825_protobuf_format_map.reference new file mode 100644 index 00000000000..e3f17cb1095 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_map.reference @@ -0,0 +1,19 @@ +{'x':5,'y':7} +{'z':11} +{'temp':0} +{'':0} + +Binary representation: +00000000 0e 0a 05 0a 01 78 10 05 0a 05 0a 01 79 10 07 07 |.....x......y...| +00000010 0a 05 0a 01 7a 10 0b 0a 0a 08 0a 04 74 65 6d 70 |....z.......temp| +00000020 10 00 06 0a 04 0a 00 10 00 |.........| +00000029 + +{'x':5,'y':7} +{'z':11} +{'temp':0} +{'':0} +{'x':5,'y':7} +{'z':11} +{'temp':0} +{'':0} diff --git a/tests/queries/0_stateless/00825_protobuf_format_map.sh b/tests/queries/0_stateless/00825_protobuf_format_map.sh new file mode 100755 index 00000000000..5df25c41750 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_map.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +# https://github.com/ClickHouse/ClickHouse/issues/6497 + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +set -eo pipefail + +# Run the client. +$CLICKHOUSE_CLIENT --multiquery <<'EOF' +SET allow_experimental_map_type = 1; + +DROP TABLE IF EXISTS map_00825; + +CREATE TABLE map_00825 +( + a Map(String, UInt32) +) ENGINE = MergeTree ORDER BY tuple(); + +INSERT INTO map_00825 VALUES ({'x':5, 'y':7}), ({'z':11}), ({'temp':0}), ({'':0}); + +SELECT * FROM map_00825; +EOF + +BINARY_FILE_PATH=$(mktemp "$CURDIR/00825_protobuf_format_map.XXXXXX.binary") +$CLICKHOUSE_CLIENT --query "SELECT * FROM map_00825 FORMAT Protobuf SETTINGS format_schema = '$CURDIR/00825_protobuf_format_map:Message'" > "$BINARY_FILE_PATH" + +# Check the output in the protobuf format +echo +echo "Binary representation:" +hexdump -C $BINARY_FILE_PATH + +# Check the input in the protobuf format (now the table contains the same data twice). +echo +$CLICKHOUSE_CLIENT --query "INSERT INTO map_00825 FORMAT Protobuf SETTINGS format_schema='$CURDIR/00825_protobuf_format_map:Message'" < "$BINARY_FILE_PATH" +$CLICKHOUSE_CLIENT --query "SELECT * FROM map_00825" + +rm "$BINARY_FILE_PATH" diff --git a/tests/queries/0_stateless/00825_protobuf_format_nested_optional.proto b/tests/queries/0_stateless/00825_protobuf_format_nested_optional.proto new file mode 100644 index 00000000000..052741f504b --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_nested_optional.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +message Repeated { + string foo = 1; + int64 bar = 2; +} + +message Message { + repeated Repeated messages = 1; +}; \ No newline at end of file diff --git a/tests/queries/0_stateless/00825_protobuf_format_nested_optional.reference b/tests/queries/0_stateless/00825_protobuf_format_nested_optional.reference new file mode 100644 index 00000000000..6cdd56a5b7f --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_nested_optional.reference @@ -0,0 +1,25 @@ +['1'] [0] +['1',''] [0,1] + +Binary representation: +00000000 05 0a 03 0a 01 31 09 0a 03 0a 01 31 0a 02 10 01 |.....1.....1....| +00000010 + +MESSAGE #1 AT 0x00000001 +messages { + foo: "1" +} +MESSAGE #2 AT 0x00000007 +messages { + foo: "1" +} +messages { + bar: 1 +} + +Binary representation is as expected + +['1'] [0] +['1',''] [0,1] +['1'] [0] +['1',''] [0,1] diff --git a/tests/queries/0_stateless/00825_protobuf_format_nested_optional.sh b/tests/queries/0_stateless/00825_protobuf_format_nested_optional.sh new file mode 100755 index 00000000000..58ded92f2c1 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_nested_optional.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# https://github.com/ClickHouse/ClickHouse/issues/6497 + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +set -eo pipefail + +# Run the client. +$CLICKHOUSE_CLIENT --multiquery <<'EOF' +DROP TABLE IF EXISTS nested_optional_protobuf_00825; + +CREATE TABLE nested_optional_protobuf_00825 +( + messages Nested + ( + foo String, + bar Int64 + ) +) ENGINE = MergeTree ORDER BY tuple(); + +INSERT INTO nested_optional_protobuf_00825 VALUES (['1'], [0]), (['1', ''], [0, 1]); + +SELECT * FROM nested_optional_protobuf_00825; +EOF + +BINARY_FILE_PATH=$(mktemp "$CURDIR/00825_protobuf_format_nested_optional.XXXXXX.binary") +$CLICKHOUSE_CLIENT --query "SELECT * FROM nested_optional_protobuf_00825 FORMAT Protobuf SETTINGS format_schema = '$CURDIR/00825_protobuf_format_nested_optional:Message'" > "$BINARY_FILE_PATH" + +# Check the output in the protobuf format +echo +$CURDIR/helpers/protobuf_length_delimited_encoder.py --decode_and_check --format_schema "$CURDIR/00825_protobuf_format_nested_optional:Message" --input "$BINARY_FILE_PATH" + +# Check the input in the protobuf format (now the table contains the same data twice). +echo +$CLICKHOUSE_CLIENT --query "INSERT INTO nested_optional_protobuf_00825 FORMAT Protobuf SETTINGS format_schema='$CURDIR/00825_protobuf_format_nested_optional:Message'" < "$BINARY_FILE_PATH" +$CLICKHOUSE_CLIENT --query "SELECT * FROM nested_optional_protobuf_00825" + +rm "$BINARY_FILE_PATH" diff --git a/tests/queries/0_stateless/00825_protobuf_format_table_default.proto b/tests/queries/0_stateless/00825_protobuf_format_table_default.proto new file mode 100644 index 00000000000..08e6049ffe0 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_table_default.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +message Message { + sint32 x = 1; + sint32 z = 2; +}; \ No newline at end of file diff --git a/tests/queries/0_stateless/00825_protobuf_format_table_default.reference b/tests/queries/0_stateless/00825_protobuf_format_table_default.reference new file mode 100644 index 00000000000..5472f3bfa14 --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_table_default.reference @@ -0,0 +1,37 @@ +0 0 0 +2 4 8 +3 9 27 +5 25 125 +101 102 103 + +Binary representation: +00000000 00 04 08 04 10 10 04 08 06 10 36 05 08 0a 10 fa |..........6.....| +00000010 01 06 08 ca 01 10 ce 01 |........| +00000018 + +MESSAGE #1 AT 0x00000001 +MESSAGE #2 AT 0x00000002 +x: 2 +z: 8 +MESSAGE #3 AT 0x00000007 +x: 3 +z: 27 +MESSAGE #4 AT 0x0000000C +x: 5 +z: 125 +MESSAGE #5 AT 0x00000012 +x: 101 +z: 103 + +Binary representation is as expected + +0 0 0 +0 0 0 +2 4 8 +2 4 8 +3 9 27 +3 9 27 +5 25 125 +5 25 125 +101 102 103 +101 10201 103 diff --git a/tests/queries/0_stateless/00825_protobuf_format_table_default.sh b/tests/queries/0_stateless/00825_protobuf_format_table_default.sh new file mode 100755 index 00000000000..97f7769269a --- /dev/null +++ b/tests/queries/0_stateless/00825_protobuf_format_table_default.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +set -eo pipefail + +# Run the client. +$CLICKHOUSE_CLIENT --multiquery <<'EOF' +DROP TABLE IF EXISTS table_default_protobuf_00825; + +CREATE TABLE table_default_protobuf_00825 +( + x Int64, + y Int64 DEFAULT x * x, + z Int64 DEFAULT x * x * x +) ENGINE = MergeTree ORDER BY tuple(); + +INSERT INTO table_default_protobuf_00825 (x) VALUES (0), (2), (3), (5); +INSERT INTO table_default_protobuf_00825 VALUES (101, 102, 103); + +SELECT * FROM table_default_protobuf_00825 ORDER BY x,y,z; +EOF + +BINARY_FILE_PATH=$(mktemp "$CURDIR/00825_protobuf_format_table_default.XXXXXX.binary") +$CLICKHOUSE_CLIENT --query "SELECT * FROM table_default_protobuf_00825 ORDER BY x,y,z FORMAT Protobuf SETTINGS format_schema = '$CURDIR/00825_protobuf_format_table_default:Message'" > "$BINARY_FILE_PATH" + +# Check the output in the protobuf format +echo +$CURDIR/helpers/protobuf_length_delimited_encoder.py --decode_and_check --format_schema "$CURDIR/00825_protobuf_format_table_default:Message" --input "$BINARY_FILE_PATH" + +# Check the input in the protobuf format (now the table contains the same data twice). +echo +$CLICKHOUSE_CLIENT --query "INSERT INTO table_default_protobuf_00825 FORMAT Protobuf SETTINGS format_schema='$CURDIR/00825_protobuf_format_table_default:Message'" < "$BINARY_FILE_PATH" +$CLICKHOUSE_CLIENT --query "SELECT * FROM table_default_protobuf_00825 ORDER BY x,y,z" + +rm "$BINARY_FILE_PATH" diff --git a/tests/queries/0_stateless/helpers/protobuf_length_delimited_encoder.py b/tests/queries/0_stateless/helpers/protobuf_length_delimited_encoder.py new file mode 100755 index 00000000000..3ed42f1c820 --- /dev/null +++ b/tests/queries/0_stateless/helpers/protobuf_length_delimited_encoder.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 + +# The protobuf compiler protoc doesn't support encoding or decoding length-delimited protobuf message. +# To do that this script has been written. + +import argparse +import os.path +import struct +import subprocess +import sys +import tempfile + +def read_varint(input): + res = 0 + shift = 0 + while True: + c = input.read(1) + if len(c) == 0: + return None + b = c[0] + if b < 0x80: + res += b << shift + break + b -= 0x80 + res += b << shift + shift = shift << 7 + return res + +def write_varint(output, value): + while True: + if value < 0x80: + b = value + output.write(b.to_bytes(1, byteorder='little')) + break + b = (value & 0x7F) + 0x80 + output.write(b.to_bytes(1, byteorder='little')) + value = value >> 7 + +def write_hexdump(output, data): + with subprocess.Popen(["hexdump", "-C"], stdin=subprocess.PIPE, stdout=output, shell=False) as proc: + proc.communicate(data) + if proc.returncode != 0: + raise RuntimeError("hexdump returned code " + str(proc.returncode)) + output.flush() + +class FormatSchemaSplitted: + def __init__(self, format_schema): + self.format_schema = format_schema + splitted = self.format_schema.split(':') + if len(splitted) < 2: + raise RuntimeError('The format schema must have the format "schemafile:MessageType"') + path = splitted[0] + self.schemadir = os.path.dirname(path) + self.schemaname = os.path.basename(path) + if not self.schemaname.endswith(".proto"): + self.schemaname = self.schemaname + ".proto" + self.message_type = splitted[1] + +def decode(input, output, format_schema): + if not type(format_schema) is FormatSchemaSplitted: + format_schema = FormatSchemaSplitted(format_schema) + msgindex = 1 + while True: + sz = read_varint(input) + if sz is None: + break + output.write("MESSAGE #{msgindex} AT 0x{msgoffset:08X}\n".format(msgindex=msgindex, msgoffset=input.tell()).encode()) + output.flush() + msg = input.read(sz) + if len(msg) < sz: + raise EOFError('Unexpected end of file') + with subprocess.Popen(["protoc", + "--decode", format_schema.message_type, format_schema.schemaname], + cwd=format_schema.schemadir, + stdin=subprocess.PIPE, + stdout=output, + shell=False) as proc: + proc.communicate(msg) + if proc.returncode != 0: + raise RuntimeError("protoc returned code " + str(proc.returncode)) + output.flush() + msgindex = msgindex + 1 + +def encode(input, output, format_schema): + if not type(format_schema) is FormatSchemaSplitted: + format_schema = FormatSchemaSplitted(format_schema) + line_offset = input.tell() + line = input.readline() + while True: + if len(line) == 0: + break + if not line.startswith(b"MESSAGE #"): + raise RuntimeError("The line at 0x{line_offset:08X} must start with the text 'MESSAGE #'".format(line_offset=line_offset)) + msg = b"" + while True: + line_offset = input.tell() + line = input.readline() + if line.startswith(b"MESSAGE #") or len(line) == 0: + break + msg += line + with subprocess.Popen(["protoc", + "--encode", format_schema.message_type, format_schema.schemaname], + cwd=format_schema.schemadir, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + shell=False) as proc: + msgbin = proc.communicate(msg)[0] + if proc.returncode != 0: + raise RuntimeError("protoc returned code " + str(proc.returncode)) + write_varint(output, len(msgbin)) + output.write(msgbin) + output.flush() + +def decode_and_check(input, output, format_schema): + input_data = input.read() + output.write(b"Binary representation:\n") + output.flush() + write_hexdump(output, input_data) + output.write(b"\n") + output.flush() + + with tempfile.TemporaryFile() as tmp_input, tempfile.TemporaryFile() as tmp_decoded, tempfile.TemporaryFile() as tmp_encoded: + tmp_input.write(input_data) + tmp_input.flush() + tmp_input.seek(0) + decode(tmp_input, tmp_decoded, format_schema) + tmp_decoded.seek(0) + decoded_text = tmp_decoded.read() + output.write(decoded_text) + output.flush() + tmp_decoded.seek(0) + encode(tmp_decoded, tmp_encoded, format_schema) + tmp_encoded.seek(0) + encoded_data = tmp_encoded.read() + + if encoded_data == input_data: + output.write(b"\nBinary representation is as expected\n") + output.flush() + else: + output.write(b"\nBinary representation differs from the expected one (listed below):\n") + output.flush() + write_hexdump(output, encoded_data) + sys.exit(1) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Encodes or decodes length-delimited protobuf messages.') + parser.add_argument('--input', help='The input file, the standard input will be used if not specified.') + parser.add_argument('--output', help='The output file, the standard output will be used if not specified') + parser.add_argument('--format_schema', required=True, help='Format schema in the format "schemafile:MessageType"') + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--encode', action='store_true', help='Specify to encode length-delimited messages.' + 'The utility will read text-format messages of the given type from the input and write it in binary to the output.') + group.add_argument('--decode', action='store_true', help='Specify to decode length-delimited messages.' + 'The utility will read messages in binary from the input and write text-format messages to the output.') + group.add_argument('--decode_and_check', action='store_true', help='The same as --decode, and the utility will then encode ' + ' the decoded data back to the binary form to check that the result of that encoding is the same as the input was.') + args = parser.parse_args() + + custom_input_file = None + custom_output_file = None + try: + if args.input: + custom_input_file = open(args.input, "rb") + if args.output: + custom_output_file = open(args.output, "wb") + input = custom_input_file if custom_input_file else sys.stdin.buffer + output = custom_output_file if custom_output_file else sys.stdout.buffer + + if args.encode: + encode(input, output, args.format_schema) + elif args.decode: + decode(input, output, args.format_schema) + elif args.decode_and_check: + decode_and_check(input, output, args.format_schema) + + finally: + if custom_input_file: + custom_input_file.close() + if custom_output_file: + custom_output_file.close() diff --git a/tests/queries/skip_list.json b/tests/queries/skip_list.json index ee25bee6a0a..0e470e14916 100644 --- a/tests/queries/skip_list.json +++ b/tests/queries/skip_list.json @@ -131,6 +131,12 @@ "00763_create_query_as_table_engine_bug", "00765_sql_compatibility_aliases", "00825_protobuf_format_input", + "00825_protobuf_format_nested_optional", + "00825_protobuf_format_array_3dim", + "00825_protobuf_format_map", + "00825_protobuf_format_array_of_arrays", + "00825_protobuf_format_table_default", + "00825_protobuf_format_enum_mapping", "00826_cross_to_inner_join", "00834_not_between", "00909_kill_not_initialized_query",