Merge pull request #6070 from bopohaa/fix-protobuf-length-delimited-message-limit

Add verification of the length of the protobuf message
This commit is contained in:
alexey-milovidov 2019-07-20 00:31:36 +03:00 committed by GitHub
commit d0eb20f4b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 15 deletions

View File

@ -41,11 +41,12 @@ namespace
constexpr UInt64 END_OF_GROUP = static_cast<UInt64>(-2);
Int64 decodeZigZag(UInt64 n) { return static_cast<Int64>((n >> 1) ^ (~(n & 1) + 1)); }
}
[[noreturn]] void unknownFormat()
{
throw Exception("Protobuf messages are corrupted or don't match the provided schema. Please note that Protobuf stream is length-delimited: every message is prefixed by its length in varint.", ErrorCodes::UNKNOWN_PROTOBUF_FORMAT);
}
[[noreturn]] void ProtobufReader::SimpleReader::throwUnknownFormat() const
{
throw Exception("Protobuf messages are corrupted or don't match the provided schema. Please note that Protobuf stream is length-delimited: every message is prefixed by its length in varint.", ErrorCodes::UNKNOWN_PROTOBUF_FORMAT);
}
@ -67,7 +68,10 @@ bool ProtobufReader::SimpleReader::startMessage()
if (unlikely(in.eof()))
return false;
size_t size_of_message = readVarint();
if (size_of_message == 0)
throwUnknownFormat();
current_message_end = cursor + size_of_message;
root_message_end = current_message_end;
}
else
{
@ -91,7 +95,7 @@ void ProtobufReader::SimpleReader::endMessage()
else if (unlikely(cursor > current_message_end))
{
if (!parent_message_ends.empty())
unknownFormat();
throwUnknownFormat();
moveCursorBackward(cursor - current_message_end);
}
current_message_end = REACHED_END;
@ -141,7 +145,7 @@ bool ProtobufReader::SimpleReader::readFieldNumber(UInt32 & field_number)
UInt64 varint = readVarint();
if (unlikely(varint & (static_cast<UInt64>(0xFFFFFFFF) << 32)))
unknownFormat();
throwUnknownFormat();
UInt32 key = static_cast<UInt32>(varint);
field_number = (key >> 3);
WireType wire_type = static_cast<WireType>(key & 0x07);
@ -171,7 +175,7 @@ bool ProtobufReader::SimpleReader::readFieldNumber(UInt32 & field_number)
case GROUP_END:
{
if (current_message_end != END_OF_GROUP)
unknownFormat();
throwUnknownFormat();
current_message_end = REACHED_END;
return false;
}
@ -181,7 +185,7 @@ bool ProtobufReader::SimpleReader::readFieldNumber(UInt32 & field_number)
return true;
}
}
unknownFormat();
throwUnknownFormat();
__builtin_unreachable();
}
@ -257,7 +261,7 @@ void ProtobufReader::SimpleReader::ignore(UInt64 num_bytes)
void ProtobufReader::SimpleReader::moveCursorBackward(UInt64 num_bytes)
{
if (in.offset() < num_bytes)
unknownFormat();
throwUnknownFormat();
in.position() -= num_bytes;
cursor -= num_bytes;
}
@ -294,7 +298,7 @@ UInt64 ProtobufReader::SimpleReader::continueReadingVarint(UInt64 first_byte)
PROTOBUF_READER_READ_VARINT_BYTE(10)
#undef PROTOBUF_READER_READ_VARINT_BYTE
unknownFormat();
throwUnknownFormat();
__builtin_unreachable();
}
@ -327,7 +331,7 @@ void ProtobufReader::SimpleReader::ignoreVarint()
PROTOBUF_READER_IGNORE_VARINT_BYTE(10)
#undef PROTOBUF_READER_IGNORE_VARINT_BYTE
unknownFormat();
throwUnknownFormat();
}
void ProtobufReader::SimpleReader::ignoreGroup()
@ -371,11 +375,10 @@ void ProtobufReader::SimpleReader::ignoreGroup()
break;
}
}
unknownFormat();
throwUnknownFormat();
}
}
// Implementation for a converter from any protobuf field type to any DB data type.
class ProtobufReader::ConverterBaseImpl : public ProtobufReader::IConverter
{

View File

@ -97,10 +97,19 @@ private:
bool readUInt(UInt64 & value);
template<typename T> bool readFixed(T & value);
bool readStringInto(PaddedPODArray<UInt8> & str);
bool ALWAYS_INLINE maybeCanReadValue() const { return field_end != REACHED_END; }
bool ALWAYS_INLINE maybeCanReadValue() const
{
if (field_end == REACHED_END)
return false;
if (cursor < root_message_end)
return true;
throwUnknownFormat();
}
private:
void readBinary(void* data, size_t size);
void readBinary(void * data, size_t size);
void ignore(UInt64 num_bytes);
void moveCursorBackward(UInt64 num_bytes);
@ -119,6 +128,8 @@ private:
void ignoreVarint();
void ignoreGroup();
[[noreturn]] void throwUnknownFormat() const;
static constexpr UInt64 REACHED_END = 0;
ReadBuffer & in;
@ -126,6 +137,8 @@ private:
std::vector<UInt64> parent_message_ends;
UInt64 current_message_end;
UInt64 field_end;
UInt64 root_message_end;
};
class IConverter