fixed heap buffer overflow in PacketPayloadWriteBuffer

This commit is contained in:
Yuriy 2019-07-29 18:41:47 +03:00
parent a109339ce6
commit c1b57f9cf5
5 changed files with 64 additions and 55 deletions

View File

@ -246,26 +246,17 @@ class PacketPayloadWriteBuffer : public WriteBuffer
{
public:
PacketPayloadWriteBuffer(WriteBuffer & out, size_t payload_length, uint8_t & sequence_id)
: WriteBuffer(out.position(), 0)
, out(out)
, sequence_id(sequence_id)
, total_left(payload_length)
: WriteBuffer(out.position(), 0), out(out), sequence_id(sequence_id), total_left(payload_length)
{
startPacket();
startNewPacket();
setWorkingBuffer();
pos = out.position();
}
void checkPayloadSize()
bool remainingPayloadSize()
{
if (bytes_written + offset() < payload_length)
{
std::stringstream ss;
ss << "Incomplete payload. Written " << bytes << " bytes, expected " << payload_length << " bytes.";
throw Exception(ss.str(), 0);
}
return total_left;
}
~PacketPayloadWriteBuffer() override { next(); }
private:
WriteBuffer & out;
uint8_t & sequence_id;
@ -273,8 +264,9 @@ private:
size_t total_left = 0;
size_t payload_length = 0;
size_t bytes_written = 0;
bool eof = false;
void startPacket()
void startNewPacket()
{
payload_length = std::min(total_left, MAX_PACKET_LENGTH);
bytes_written = 0;
@ -282,33 +274,38 @@ private:
out.write(reinterpret_cast<char *>(&payload_length), 3);
out.write(sequence_id++);
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
pos = working_buffer.begin();
bytes += 4;
}
/// Sets working buffer to the rest of current packet payload.
void setWorkingBuffer()
{
out.nextIfAtEnd();
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
if (payload_length - bytes_written == 0)
{
/// Finished writing packet. Due to an implementation of WriteBuffer, working_buffer cannot be empty. Further write attempts will throw Exception.
eof = true;
working_buffer.resize(1);
}
}
protected:
void nextImpl() override
{
int written = pos - working_buffer.begin();
const int written = pos - working_buffer.begin();
if (eof)
throw Exception("Cannot write after end of buffer.", ErrorCodes::CANNOT_WRITE_AFTER_END_OF_BUFFER);
out.position() += written;
bytes_written += written;
if (bytes_written < payload_length)
{
out.nextIfAtEnd();
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
}
else if (total_left > 0 || payload_length == MAX_PACKET_LENGTH)
{
// Starting new packet, since packets of size greater than MAX_PACKET_LENGTH should be split.
startPacket();
}
else
{
// Finished writing packet. Buffer is set to empty to prevent rewriting (pos will be set to the beginning of a working buffer in next()).
// Further attempts to write will stall in the infinite loop.
working_buffer = WriteBuffer::Buffer(out.position(), out.position());
}
/// Packets of size greater than MAX_PACKET_LENGTH are split into few packets of size MAX_PACKET_LENGTH and las packet of size < MAX_PACKET_LENGTH.
if (bytes_written == payload_length && (total_left > 0 || payload_length == MAX_PACKET_LENGTH))
startNewPacket();
setWorkingBuffer();
}
};
@ -320,7 +317,13 @@ public:
{
PacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
writePayloadImpl(buf);
buf.checkPayloadSize();
buf.next();
if (buf.remainingPayloadSize())
{
std::stringstream ss;
ss << "Incomplete payload. Written " << getPayloadSize() - buf.remainingPayloadSize() << " bytes, expected " << getPayloadSize() << " bytes.";
throw Exception(ss.str(), 0);
}
}
virtual ~WritePacket() = default;

View File

@ -55,16 +55,18 @@ void MySQLWireBlockOutputStream::write(const Block & block)
void MySQLWireBlockOutputStream::writeSuffix()
{
QueryStatus * process_list_elem = context.getProcessListElement();
CurrentThread::finalizePerformanceCounters();
QueryStatusInfo info = process_list_elem->getInfo();
size_t affected_rows = info.written_rows;
size_t affected_rows = 0;
std::stringstream human_readable_info;
human_readable_info << std::fixed << std::setprecision(3)
<< "Read " << info.read_rows << " rows, " << formatReadableSizeWithBinarySuffix(info.read_bytes) << " in " << info.elapsed_seconds << " sec., "
<< static_cast<size_t>(info.read_rows / info.elapsed_seconds) << " rows/sec., "
<< formatReadableSizeWithBinarySuffix(info.read_bytes / info.elapsed_seconds) << "/sec.";
if (QueryStatus * process_list_elem = context.getProcessListElement())
{
CurrentThread::finalizePerformanceCounters();
QueryStatusInfo info = process_list_elem->getInfo();
affected_rows = info.written_rows;
human_readable_info << std::fixed << std::setprecision(3)
<< "Read " << info.read_rows << " rows, " << formatReadableSizeWithBinarySuffix(info.read_bytes) << " in " << info.elapsed_seconds << " sec., "
<< static_cast<size_t>(info.read_rows / info.elapsed_seconds) << " rows/sec., "
<< formatReadableSizeWithBinarySuffix(info.read_bytes / info.elapsed_seconds) << "/sec.";
}
if (header.columns() == 0)
packet_sender.sendPacket(OK_Packet(0x0, context.mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);

View File

@ -71,16 +71,18 @@ void MySQLOutputFormat::consume(Chunk chunk)
void MySQLOutputFormat::finalize()
{
QueryStatus * process_list_elem = context.getProcessListElement();
CurrentThread::finalizePerformanceCounters();
QueryStatusInfo info = process_list_elem->getInfo();
size_t affected_rows = info.written_rows;
size_t affected_rows = 0;
std::stringstream human_readable_info;
human_readable_info << std::fixed << std::setprecision(3)
<< "Read " << info.read_rows << " rows, " << formatReadableSizeWithBinarySuffix(info.read_bytes) << " in " << info.elapsed_seconds << " sec., "
<< static_cast<size_t>(info.read_rows / info.elapsed_seconds) << " rows/sec., "
<< formatReadableSizeWithBinarySuffix(info.read_bytes / info.elapsed_seconds) << "/sec.";
if (QueryStatus * process_list_elem = context.getProcessListElement())
{
CurrentThread::finalizePerformanceCounters();
QueryStatusInfo info = process_list_elem->getInfo();
affected_rows = info.written_rows;
human_readable_info << std::fixed << std::setprecision(3)
<< "Read " << info.read_rows << " rows, " << formatReadableSizeWithBinarySuffix(info.read_bytes) << " in " << info.elapsed_seconds << " sec., "
<< static_cast<size_t>(info.read_rows / info.elapsed_seconds) << " rows/sec., "
<< formatReadableSizeWithBinarySuffix(info.read_bytes / info.elapsed_seconds) << "/sec.";
}
auto & header = getPort(PortKind::Main).getHeader();

View File

@ -35,6 +35,7 @@
<value>Parquet</value>
<value>ODBCDriver2</value>
<value>Null</value>
<value>MySQLWire</value>
</values>
</substitution>
</substitutions>

View File

@ -44,6 +44,7 @@
<value>Native</value>
<value>XML</value>
<value>ODBCDriver2</value>
<value>MySQLWire</value>
</values>
</substitution>
</substitutions>