WIP prepared statements

This commit is contained in:
slvrtrn 2023-08-25 20:31:21 +02:00
parent 7fdb414793
commit 734ffd916c
16 changed files with 1051 additions and 576 deletions

View File

@ -8,254 +8,263 @@ namespace DB
namespace MySQLProtocol
{
namespace Generic
{
static const size_t MYSQL_ERRMSG_SIZE = 512;
void SSLRequest::readPayloadImpl(ReadBuffer & buf)
{
buf.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
buf.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
buf.readStrict(reinterpret_cast<char *>(&character_set), 1);
}
OKPacket::OKPacket(uint32_t capabilities_)
: header(0x00), capabilities(capabilities_), affected_rows(0x00), last_insert_id(0x00), status_flags(0x00)
{
}
OKPacket::OKPacket(
uint8_t header_, uint32_t capabilities_, uint64_t affected_rows_, uint32_t status_flags_, int16_t warnings_,
String session_state_changes_, String info_)
: header(header_), capabilities(capabilities_), affected_rows(affected_rows_), last_insert_id(0), warnings(warnings_),
status_flags(status_flags_), session_state_changes(std::move(session_state_changes_)), info(std::move(info_))
{
}
size_t OKPacket::getPayloadSize() const
{
size_t result = 2 + getLengthEncodedNumberSize(affected_rows);
if (capabilities & CLIENT_PROTOCOL_41)
namespace Generic
{
result += 4;
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
result += 2;
}
if (capabilities & CLIENT_SESSION_TRACK)
{
result += getLengthEncodedStringSize(info);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
result += getLengthEncodedStringSize(session_state_changes);
}
else
{
result += info.size();
}
static const size_t MYSQL_ERRMSG_SIZE = 512;
return result;
}
void OKPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
affected_rows = readLengthEncodedNumber(payload);
last_insert_id = readLengthEncodedNumber(payload);
if (capabilities & CLIENT_PROTOCOL_41)
{
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
payload.readStrict(reinterpret_cast<char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
readLengthEncodedString(info, payload);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
void SSLRequest::readPayloadImpl(ReadBuffer & buf)
{
readLengthEncodedString(session_state_changes, payload);
buf.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
buf.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
buf.readStrict(reinterpret_cast<char *>(&character_set), 1);
}
}
else
{
readString(info, payload);
}
}
void OKPacket::writePayloadImpl(WriteBuffer & buffer) const
OKPacket::OKPacket(uint32_t capabilities_)
: header(0x00), capabilities(capabilities_), affected_rows(0x00), last_insert_id(0x00), status_flags(0x00)
{
}
{
buffer.write(header);
writeLengthEncodedNumber(affected_rows, buffer);
writeLengthEncodedNumber(last_insert_id, buffer); /// last insert-id
OKPacket::OKPacket(
uint8_t header_,
uint32_t capabilities_,
uint64_t affected_rows_,
uint32_t status_flags_,
int16_t warnings_,
String session_state_changes_,
String info_)
: header(header_)
, capabilities(capabilities_)
, affected_rows(affected_rows_)
, last_insert_id(0)
, warnings(warnings_)
, status_flags(status_flags_)
, session_state_changes(std::move(session_state_changes_))
, info(std::move(info_))
{
}
if (capabilities & CLIENT_PROTOCOL_41)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
size_t OKPacket::getPayloadSize() const
{
size_t result = 2 + getLengthEncodedNumberSize(affected_rows);
if (capabilities & CLIENT_SESSION_TRACK)
{
writeLengthEncodedString(info, buffer);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
writeLengthEncodedString(session_state_changes, buffer);
}
else
{
writeString(info, buffer);
}
}
EOFPacket::EOFPacket() : warnings(0x00), status_flags(0x00)
{
}
EOFPacket::EOFPacket(int warnings_, int status_flags_)
: warnings(warnings_), status_flags(status_flags_)
{
}
size_t EOFPacket::getPayloadSize() const
{
return 5;
}
void EOFPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xfe);
payload.readStrict(reinterpret_cast<char *>(&warnings), 2);
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
}
void EOFPacket::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(header); // EOF header
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
void AuthSwitchPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xfe);
readStringUntilEOF(plugin_name, payload);
}
ERRPacket::ERRPacket() : error_code(0x00)
{
}
ERRPacket::ERRPacket(int error_code_, String sql_state_, String error_message_)
: error_code(error_code_), sql_state(std::move(sql_state_)), error_message(std::move(error_message_))
{
}
size_t ERRPacket::getPayloadSize() const
{
return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE);
}
void ERRPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xff);
payload.readStrict(reinterpret_cast<char *>(&error_code), 2);
/// SQL State [optional: # + 5bytes string]
UInt8 sharp = static_cast<unsigned char>(*payload.position());
if (sharp == 0x23)
{
payload.ignore(1);
sql_state.resize(5);
payload.readStrict(reinterpret_cast<char *>(sql_state.data()), 5);
}
readString(error_message, payload);
}
void ERRPacket::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(header);
buffer.write(reinterpret_cast<const char *>(&error_code), 2);
buffer.write('#');
buffer.write(sql_state.data(), sql_state.length());
buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
}
ResponsePacket::ResponsePacket(UInt32 server_capability_flags_)
: ok(OKPacket(server_capability_flags_))
{
}
ResponsePacket::ResponsePacket(UInt32 server_capability_flags_, bool is_handshake_)
: ok(OKPacket(server_capability_flags_)), is_handshake(is_handshake_)
{
}
void ResponsePacket::readPayloadImpl(ReadBuffer & payload)
{
UInt16 header = static_cast<unsigned char>(*payload.position());
switch (header)
{
case PACKET_OK:
packetType = PACKET_OK;
ok.readPayloadWithUnpacked(payload);
break;
case PACKET_ERR:
packetType = PACKET_ERR;
err.readPayloadWithUnpacked(payload);
break;
case PACKET_EOF:
if (is_handshake)
if (capabilities & CLIENT_PROTOCOL_41)
{
packetType = PACKET_AUTH_SWITCH;
auth_switch.readPayloadWithUnpacked(payload);
result += 4;
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
result += 2;
}
if (capabilities & CLIENT_SESSION_TRACK)
{
result += getLengthEncodedStringSize(info);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
result += getLengthEncodedStringSize(session_state_changes);
}
else
{
packetType = PACKET_EOF;
eof.readPayloadWithUnpacked(payload);
result += info.size();
}
break;
case PACKET_LOCALINFILE:
packetType = PACKET_LOCALINFILE;
break;
default:
packetType = PACKET_OK;
column_length = readLengthEncodedNumber(payload);
return result;
}
void OKPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
affected_rows = readLengthEncodedNumber(payload);
last_insert_id = readLengthEncodedNumber(payload);
if (capabilities & CLIENT_PROTOCOL_41)
{
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
payload.readStrict(reinterpret_cast<char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
readLengthEncodedString(info, payload);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
{
readLengthEncodedString(session_state_changes, payload);
}
}
else
{
readString(info, payload);
}
}
void OKPacket::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(header);
writeLengthEncodedNumber(affected_rows, buffer);
writeLengthEncodedNumber(last_insert_id, buffer); /// last insert-id
if (capabilities & CLIENT_PROTOCOL_41)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
writeLengthEncodedString(info, buffer);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
writeLengthEncodedString(session_state_changes, buffer);
}
else
{
writeString(info, buffer);
}
}
EOFPacket::EOFPacket() : warnings(0x00), status_flags(0x00)
{
}
EOFPacket::EOFPacket(int warnings_, int status_flags_) : warnings(warnings_), status_flags(status_flags_)
{
}
size_t EOFPacket::getPayloadSize() const
{
return 5;
}
void EOFPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xfe);
payload.readStrict(reinterpret_cast<char *>(&warnings), 2);
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
}
void EOFPacket::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(header); // EOF header
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
void AuthSwitchPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xfe);
readStringUntilEOF(plugin_name, payload);
}
ERRPacket::ERRPacket() : error_code(0x00)
{
}
ERRPacket::ERRPacket(int error_code_, String sql_state_, String error_message_)
: error_code(error_code_), sql_state(std::move(sql_state_)), error_message(std::move(error_message_))
{
}
size_t ERRPacket::getPayloadSize() const
{
return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE);
}
void ERRPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xff);
payload.readStrict(reinterpret_cast<char *>(&error_code), 2);
/// SQL State [optional: # + 5bytes string]
UInt8 sharp = static_cast<unsigned char>(*payload.position());
if (sharp == 0x23)
{
payload.ignore(1);
sql_state.resize(5);
payload.readStrict(reinterpret_cast<char *>(sql_state.data()), 5);
}
readString(error_message, payload);
}
void ERRPacket::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(header);
buffer.write(reinterpret_cast<const char *>(&error_code), 2);
buffer.write('#');
buffer.write(sql_state.data(), sql_state.length());
buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
}
ResponsePacket::ResponsePacket(UInt32 server_capability_flags_) : ok(OKPacket(server_capability_flags_))
{
}
ResponsePacket::ResponsePacket(UInt32 server_capability_flags_, bool is_handshake_)
: ok(OKPacket(server_capability_flags_)), is_handshake(is_handshake_)
{
}
void ResponsePacket::readPayloadImpl(ReadBuffer & payload)
{
UInt16 header = static_cast<unsigned char>(*payload.position());
switch (header)
{
case PACKET_OK:
packetType = PACKET_OK;
ok.readPayloadWithUnpacked(payload);
break;
case PACKET_ERR:
packetType = PACKET_ERR;
err.readPayloadWithUnpacked(payload);
break;
case PACKET_EOF:
if (is_handshake)
{
packetType = PACKET_AUTH_SWITCH;
auth_switch.readPayloadWithUnpacked(payload);
}
else
{
packetType = PACKET_EOF;
eof.readPayloadWithUnpacked(payload);
}
break;
case PACKET_LOCALINFILE:
packetType = PACKET_LOCALINFILE;
break;
default:
packetType = PACKET_OK;
column_length = readLengthEncodedNumber(payload);
}
}
LengthEncodedNumber::LengthEncodedNumber(uint64_t value_) : value(value_)
{
}
size_t LengthEncodedNumber::getPayloadSize() const
{
return getLengthEncodedNumberSize(value);
}
void LengthEncodedNumber::writePayloadImpl(WriteBuffer & buffer) const
{
writeLengthEncodedNumber(value, buffer);
}
}
}
LengthEncodedNumber::LengthEncodedNumber(uint64_t value_) : value(value_)
{
}
size_t LengthEncodedNumber::getPayloadSize() const
{
return getLengthEncodedNumberSize(value);
}
void LengthEncodedNumber::writePayloadImpl(WriteBuffer & buffer) const
{
writeLengthEncodedNumber(value, buffer);
}
}
}

View File

@ -0,0 +1,230 @@
#include <Columns/IColumn.h>
#include <Core/MySQL/IMySQLReadPacket.h>
#include <Core/MySQL/IMySQLWritePacket.h>
#include <Core/MySQL/PacketsProtocolBinary.h>
#include <Poco/DateTime.h>
#include <Poco/Timestamp.h>
#include "Columns/ColumnLowCardinality.h"
#include "Columns/ColumnVector.h"
#include "DataTypes/DataTypeLowCardinality.h"
#include "DataTypes/DataTypeNullable.h"
#include "Formats/FormatSettings.h"
#include "IO/WriteBufferFromString.h"
#include "base/types.h"
namespace DB
{
namespace MySQLProtocol
{
namespace ProtocolBinary
{
ResultSetRow::ResultSetRow(
const Serializations & serializations_, const DataTypes & data_types_, const Columns & columns_, int row_num_)
: row_num(row_num_), columns(columns_), data_types(data_types_), serializations(serializations_)
{
/// See https://dev.mysql.com/doc/dev/mysql-server/8.1.0/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row
payload_size = 1 + null_bitmap_size;
// LOG_TRACE(&Poco::Logger::get("ResultSetRow"), "Null bitmap size: {}", null_bitmap_size);
FormatSettings format_settings;
for (size_t i = 0; i < columns.size(); ++i)
{
ColumnPtr col = columns[i];
if (col->isNullAt(row_num))
{
null_bitmap[i / 8] |= 1 << i % 8;
}
TypeIndex type_index = removeNullable(removeLowCardinality(data_types[i]))->getTypeId();
switch (type_index)
{
case TypeIndex::Int8:
case TypeIndex::UInt8:
payload_size += 1;
break;
case TypeIndex::Int16:
case TypeIndex::UInt16:
payload_size += 2;
break;
case TypeIndex::Int32:
case TypeIndex::UInt32:
case TypeIndex::Float32:
payload_size += 4;
break;
case TypeIndex::Int64:
case TypeIndex::UInt64:
case TypeIndex::Float64:
payload_size += 8;
break;
case TypeIndex::Date: {
UInt64 value = col->get64(row_num);
if (value == 0)
{
payload_size += 1; // length only, no other fields
}
else
{
payload_size += 5;
}
break;
}
case TypeIndex::DateTime: {
UInt64 value = col->get64(row_num);
if (value == 0)
{
payload_size += 1; // length only, no other fields
}
else
{
Poco::DateTime dt = Poco::DateTime(Poco::Timestamp(value * 1000 * 1000));
if (dt.second() == 0 && dt.minute() == 0 && dt.hour() == 0)
{
payload_size += 5;
}
else
{
payload_size += 8;
}
}
break;
}
default:
WriteBufferFromOwnString ostr;
serializations[i]->serializeText(*columns[i], row_num, ostr, format_settings);
payload_size += getLengthEncodedStringSize(ostr.str());
serialized[i] = std::move(ostr.str());
break;
}
}
}
void ResultSetRow::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(static_cast<char>(0x00));
buffer.write(null_bitmap.data(), null_bitmap_size);
for (size_t i = 0; i < columns.size(); ++i)
{
ColumnPtr col = columns[i];
if (col->isNullAt(row_num))
{
continue; // NULLs are stored in the null bitmap only
}
TypeIndex type_index = removeNullable(removeLowCardinality(data_types[i]))->getTypeId();
switch (type_index)
{
case TypeIndex::UInt8: {
UInt64 value = col->get64(row_num);
buffer.write(reinterpret_cast<char *>(&value), 1);
break;
}
case TypeIndex::UInt16: {
UInt64 value = col->get64(row_num);
buffer.write(reinterpret_cast<char *>(&value), 2);
break;
}
case TypeIndex::UInt32: {
UInt64 value = col->get64(row_num);
buffer.write(reinterpret_cast<char *>(&value), 4);
break;
}
case TypeIndex::UInt64: {
UInt64 value = col->get64(row_num);
buffer.write(reinterpret_cast<char *>(&value), 8);
break;
}
case TypeIndex::Int8: {
UInt64 value = col->get64(row_num);
buffer.write(reinterpret_cast<char *>(&value), 1);
break;
}
case TypeIndex::Int16: {
UInt64 value = col->get64(row_num);
buffer.write(reinterpret_cast<char *>(&value), 2);
break;
}
case TypeIndex::Int32: {
UInt64 value = col->get64(row_num);
buffer.write(reinterpret_cast<char *>(&value), 4);
break;
}
case TypeIndex::Int64: {
UInt64 value = col->get64(row_num);
buffer.write(reinterpret_cast<char *>(&value), 8);
break;
}
case TypeIndex::Float32: {
Float32 value = col->getFloat32(row_num);
buffer.write(reinterpret_cast<char *>(&value), 4);
break;
}
case TypeIndex::Float64: {
Float64 value = col->getFloat64(row_num);
buffer.write(reinterpret_cast<char *>(&value), 8);
break;
}
case TypeIndex::Date: {
UInt64 value = col->get64(row_num);
if (value != 0)
{
Poco::DateTime dt = Poco::DateTime(Poco::Timestamp(value * 1000 * 1000));
buffer.write(static_cast<char>(4)); // bytes_following
int year = dt.year();
int month = dt.month();
int day = dt.day();
buffer.write(reinterpret_cast<const char *>(&year), 2);
buffer.write(reinterpret_cast<const char *>(&month), 1);
buffer.write(reinterpret_cast<const char *>(&day), 1);
}
else
{
buffer.write(static_cast<char>(0));
}
break;
}
case TypeIndex::DateTime: {
UInt64 value = col->get64(row_num);
if (value != 0)
{
Poco::DateTime dt = Poco::DateTime(Poco::Timestamp(value * 1000 * 1000));
bool is_date_time = !(dt.hour() == 0 && dt.minute() == 0 && dt.second() == 0);
size_t bytes_following = is_date_time ? 7 : 4;
buffer.write(reinterpret_cast<const char *>(&bytes_following), 1);
int year = dt.year();
int month = dt.month();
int day = dt.day();
buffer.write(reinterpret_cast<const char *>(&year), 2);
buffer.write(reinterpret_cast<const char *>(&month), 1);
buffer.write(reinterpret_cast<const char *>(&day), 1);
if (is_date_time)
{
int hour = dt.hourAMPM();
int minute = dt.minute();
int second = dt.second();
buffer.write(reinterpret_cast<const char *>(&hour), 1);
buffer.write(reinterpret_cast<const char *>(&minute), 1);
buffer.write(reinterpret_cast<const char *>(&second), 1);
}
}
else
{
buffer.write(static_cast<char>(0));
}
break;
}
default:
writeLengthEncodedString(serialized[i], buffer);
break;
}
}
}
size_t ResultSetRow::getPayloadSize() const
{
return payload_size;
};
}
}
}

View File

@ -0,0 +1,45 @@
#pragma once
#include <vector>
#include <Columns/IColumn.h>
#include <Core/MySQL/IMySQLReadPacket.h>
#include <Core/MySQL/IMySQLWritePacket.h>
#include "DataTypes/IDataType.h"
#include "DataTypes/Serializations/ISerialization.h"
namespace DB
{
namespace MySQLProtocol
{
namespace ProtocolBinary
{
class ResultSetRow : public IMySQLWritePacket
{
private:
TypeIndex getTypeIndex(DataTypePtr data_type, const ColumnPtr & col) const;
protected:
int row_num;
const Columns & columns;
const DataTypes & data_types;
const Serializations & serializations;
std::vector<String> serialized = std::vector<String>(columns.size());
size_t null_bitmap_size = (columns.size() + 7) / 8;
std::vector<char> null_bitmap = std::vector<char>(null_bitmap_size, 0);
size_t payload_size = 0;
size_t getPayloadSize() const override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
ResultSetRow(const Serializations & serializations_, const DataTypes & data_types_, const Columns & columns_, int row_num_);
};
}
}
}

View File

@ -1,7 +1,8 @@
#include <Core/MySQL/PacketsProtocolText.h>
#include <IO/WriteBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromString.h>
#include <IO/WriteHelpers.h>
#include "Core/MySQL/IMySQLWritePacket.h"
namespace DB
{
@ -9,197 +10,212 @@ namespace DB
namespace MySQLProtocol
{
namespace ProtocolText
{
ResultSetRow::ResultSetRow(const Serializations & serializations, const Columns & columns_, int row_num_)
: columns(columns_), row_num(row_num_)
{
for (size_t i = 0; i < columns.size(); ++i)
namespace ProtocolText
{
if (columns[i]->isNullAt(row_num))
ResultSetRow::ResultSetRow(const Serializations & serializations, const Columns & columns_, int row_num_)
: columns(columns_), row_num(row_num_)
{
payload_size += 1;
serialized.emplace_back("\xfb");
for (size_t i = 0; i < columns.size(); ++i)
{
if (columns[i]->isNullAt(row_num))
{
payload_size += 1;
serialized.emplace_back("\xfb");
}
else
{
WriteBufferFromOwnString ostr;
serializations[i]->serializeText(*columns[i], row_num, ostr, FormatSettings());
payload_size += getLengthEncodedStringSize(ostr.str());
serialized.push_back(std::move(ostr.str()));
}
}
}
else
size_t ResultSetRow::getPayloadSize() const
{
WriteBufferFromOwnString ostr;
serializations[i]->serializeText(*columns[i], row_num, ostr, FormatSettings());
payload_size += getLengthEncodedStringSize(ostr.str());
serialized.push_back(std::move(ostr.str()));
return payload_size;
}
void ResultSetRow::writePayloadImpl(WriteBuffer & buffer) const
{
for (size_t i = 0; i < columns.size(); ++i)
{
if (columns[i]->isNullAt(row_num))
buffer.write(serialized[i].data(), 1);
else
writeLengthEncodedString(serialized[i], buffer);
}
}
void ComFieldList::readPayloadImpl(ReadBuffer & payload)
{
// Command byte has been already read from payload.
readNullTerminated(table, payload);
readStringUntilEOF(field_wildcard, payload);
}
ColumnDefinition::ColumnDefinition() : character_set(0x00), column_length(0), column_type(MYSQL_TYPE_DECIMAL), flags(0x00)
{
}
ColumnDefinition::ColumnDefinition(
String schema_,
String table_,
String org_table_,
String name_,
String org_name_,
uint16_t character_set_,
uint32_t column_length_,
ColumnType column_type_,
uint16_t flags_,
uint8_t decimals_,
bool with_defaults_)
: schema(std::move(schema_))
, table(std::move(table_))
, org_table(std::move(org_table_))
, name(std::move(name_))
, org_name(std::move(org_name_))
, character_set(character_set_)
, column_length(column_length_)
, column_type(column_type_)
, flags(flags_)
, decimals(decimals_)
, is_comm_field_list_response(with_defaults_)
{
}
ColumnDefinition::ColumnDefinition(
String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_)
: ColumnDefinition("", "", "", std::move(name_), "", character_set_, column_length_, column_type_, flags_, decimals_)
{
}
size_t ColumnDefinition::getPayloadSize() const
{
return 12 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table)
+ getLengthEncodedStringSize(org_table) + getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name)
+ getLengthEncodedNumberSize(next_length) + is_comm_field_list_response;
}
void ColumnDefinition::readPayloadImpl(ReadBuffer & payload)
{
String def;
readLengthEncodedString(def, payload);
assert(def == "def");
readLengthEncodedString(schema, payload);
readLengthEncodedString(table, payload);
readLengthEncodedString(org_table, payload);
readLengthEncodedString(name, payload);
readLengthEncodedString(org_name, payload);
next_length = readLengthEncodedNumber(payload);
payload.readStrict(reinterpret_cast<char *>(&character_set), 2);
payload.readStrict(reinterpret_cast<char *>(&column_length), 4);
payload.readStrict(reinterpret_cast<char *>(&column_type), 1);
payload.readStrict(reinterpret_cast<char *>(&flags), 2);
payload.readStrict(reinterpret_cast<char *>(&decimals), 1);
payload.ignore(2);
}
void ColumnDefinition::writePayloadImpl(WriteBuffer & buffer) const
{
writeLengthEncodedString(std::string("def"), buffer); /// always "def"
writeLengthEncodedString(schema, buffer);
writeLengthEncodedString(table, buffer);
writeLengthEncodedString(org_table, buffer);
writeLengthEncodedString(name, buffer);
writeLengthEncodedString(org_name, buffer);
writeLengthEncodedNumber(next_length, buffer);
buffer.write(reinterpret_cast<const char *>(&character_set), 2);
buffer.write(reinterpret_cast<const char *>(&column_length), 4);
buffer.write(reinterpret_cast<const char *>(&column_type), 1);
buffer.write(reinterpret_cast<const char *>(&flags), 2);
buffer.write(reinterpret_cast<const char *>(&decimals), 1);
writeChar(0x0, 2, buffer);
if (is_comm_field_list_response)
{
/// We should write length encoded int with string size
/// followed by string with some "default values" (possibly it's column defaults).
/// But we just send NULL for simplicity.
writeChar(0xfb, buffer);
}
}
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index)
{
ColumnType column_type;
CharacterSet charset = CharacterSet::binary;
int flags = 0;
uint8_t decimals = 0;
switch (type_index)
{
case TypeIndex::UInt8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::Int8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float32:
column_type = ColumnType::MYSQL_TYPE_FLOAT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
decimals = 31;
break;
case TypeIndex::Float64:
column_type = ColumnType::MYSQL_TYPE_DOUBLE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
decimals = 31;
break;
case TypeIndex::Date:
column_type = ColumnType::MYSQL_TYPE_DATE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::DateTime:
column_type = ColumnType::MYSQL_TYPE_DATETIME;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Decimal32:
case TypeIndex::Decimal64:
case TypeIndex::Decimal128:
/// MySQL Decimal has max 65 precision and 30 scale. Thus, Decimal256 is reported as a string
column_type = ColumnType::MYSQL_TYPE_DECIMAL;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
default:
column_type = ColumnType::MYSQL_TYPE_STRING;
charset = CharacterSet::utf8_general_ci;
break;
}
return ColumnDefinition(column_name, charset, 0, column_type, flags, decimals);
}
}
}
size_t ResultSetRow::getPayloadSize() const
{
return payload_size;
}
void ResultSetRow::writePayloadImpl(WriteBuffer & buffer) const
{
for (size_t i = 0; i < columns.size(); ++i)
{
if (columns[i]->isNullAt(row_num))
buffer.write(serialized[i].data(), 1);
else
writeLengthEncodedString(serialized[i], buffer);
}
}
void ComFieldList::readPayloadImpl(ReadBuffer & payload)
{
// Command byte has been already read from payload.
readNullTerminated(table, payload);
readStringUntilEOF(field_wildcard, payload);
}
ColumnDefinition::ColumnDefinition()
: character_set(0x00), column_length(0), column_type(MYSQL_TYPE_DECIMAL), flags(0x00)
{
}
ColumnDefinition::ColumnDefinition(
String schema_, String table_, String org_table_, String name_, String org_name_, uint16_t character_set_, uint32_t column_length_,
ColumnType column_type_, uint16_t flags_, uint8_t decimals_, bool with_defaults_)
: schema(std::move(schema_)), table(std::move(table_)), org_table(std::move(org_table_)), name(std::move(name_)),
org_name(std::move(org_name_)), character_set(character_set_), column_length(column_length_), column_type(column_type_),
flags(flags_), decimals(decimals_), is_comm_field_list_response(with_defaults_)
{
}
ColumnDefinition::ColumnDefinition(
String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_)
: ColumnDefinition("", "", "", std::move(name_), "", character_set_, column_length_, column_type_, flags_, decimals_)
{
}
size_t ColumnDefinition::getPayloadSize() const
{
return 12 +
getLengthEncodedStringSize("def") +
getLengthEncodedStringSize(schema) +
getLengthEncodedStringSize(table) +
getLengthEncodedStringSize(org_table) +
getLengthEncodedStringSize(name) +
getLengthEncodedStringSize(org_name) +
getLengthEncodedNumberSize(next_length) +
is_comm_field_list_response;
}
void ColumnDefinition::readPayloadImpl(ReadBuffer & payload)
{
String def;
readLengthEncodedString(def, payload);
assert(def == "def");
readLengthEncodedString(schema, payload);
readLengthEncodedString(table, payload);
readLengthEncodedString(org_table, payload);
readLengthEncodedString(name, payload);
readLengthEncodedString(org_name, payload);
next_length = readLengthEncodedNumber(payload);
payload.readStrict(reinterpret_cast<char *>(&character_set), 2);
payload.readStrict(reinterpret_cast<char *>(&column_length), 4);
payload.readStrict(reinterpret_cast<char *>(&column_type), 1);
payload.readStrict(reinterpret_cast<char *>(&flags), 2);
payload.readStrict(reinterpret_cast<char *>(&decimals), 1);
payload.ignore(2);
}
void ColumnDefinition::writePayloadImpl(WriteBuffer & buffer) const
{
writeLengthEncodedString(std::string("def"), buffer); /// always "def"
writeLengthEncodedString(schema, buffer);
writeLengthEncodedString(table, buffer);
writeLengthEncodedString(org_table, buffer);
writeLengthEncodedString(name, buffer);
writeLengthEncodedString(org_name, buffer);
writeLengthEncodedNumber(next_length, buffer);
buffer.write(reinterpret_cast<const char *>(&character_set), 2);
buffer.write(reinterpret_cast<const char *>(&column_length), 4);
buffer.write(reinterpret_cast<const char *>(&column_type), 1);
buffer.write(reinterpret_cast<const char *>(&flags), 2);
buffer.write(reinterpret_cast<const char *>(&decimals), 1);
writeChar(0x0, 2, buffer);
if (is_comm_field_list_response)
{
/// We should write length encoded int with string size
/// followed by string with some "default values" (possibly it's column defaults).
/// But we just send NULL for simplicity.
writeChar(0xfb, buffer);
}
}
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index)
{
ColumnType column_type;
CharacterSet charset = CharacterSet::binary;
int flags = 0;
switch (type_index)
{
case TypeIndex::UInt8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::Int8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float32:
column_type = ColumnType::MYSQL_TYPE_FLOAT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float64:
column_type = ColumnType::MYSQL_TYPE_DOUBLE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Date:
column_type = ColumnType::MYSQL_TYPE_DATE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::DateTime:
column_type = ColumnType::MYSQL_TYPE_DATETIME;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::String:
case TypeIndex::FixedString:
column_type = ColumnType::MYSQL_TYPE_STRING;
charset = CharacterSet::utf8_general_ci;
break;
default:
column_type = ColumnType::MYSQL_TYPE_STRING;
charset = CharacterSet::utf8_general_ci;
break;
}
return ColumnDefinition(column_name, charset, 0, column_type, flags, 0);
}
}
}

View File

@ -1,13 +1,13 @@
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/Serializations/SerializationDecimal.h>
#include <Common/typeid_cast.h>
#include <Core/DecimalFunctions.h>
#include <DataTypes/DataTypeFactory.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <IO/readDecimalText.h>
#include <Parsers/ASTLiteral.h>
#include <Common/typeid_cast.h>
#include <type_traits>
@ -31,6 +31,12 @@ std::string DataTypeDecimal<T>::doGetName() const
template <is_decimal T>
std::string DataTypeDecimal<T>::getSQLCompatibleName() const
{
/// See https://dev.mysql.com/doc/refman/8.0/en/precision-math-decimal-characteristics.html
/// DECIMAL(M,D)
/// M is the maximum number of digits (the precision). It has a range of 1 to 65.
/// D is the number of digits to the right of the decimal point (the scale). It has a range of 0 to 30 and must be no larger than M.
if (this->precision > 65 || this->scale > 30)
return "TEXT";
return fmt::format("DECIMAL({}, {})", this->precision, this->scale);
}
@ -75,14 +81,14 @@ SerializationPtr DataTypeDecimal<T>::doGetDefaultSerialization() const
static DataTypePtr create(const ASTPtr & arguments)
{
if (!arguments || arguments->children.size() != 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Decimal data type family must have exactly two arguments: precision and scale");
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Decimal data type family must have exactly two arguments: precision and scale");
const auto * precision = arguments->children[0]->as<ASTLiteral>();
const auto * scale = arguments->children[1]->as<ASTLiteral>();
if (!precision || precision->value.getType() != Field::Types::UInt64 ||
!scale || !(scale->value.getType() == Field::Types::Int64 || scale->value.getType() == Field::Types::UInt64))
if (!precision || precision->value.getType() != Field::Types::UInt64 || !scale
|| !(scale->value.getType() == Field::Types::Int64 || scale->value.getType() == Field::Types::UInt64))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Decimal data type family must have two numbers as its arguments");
UInt64 precision_value = precision->value.get<UInt64>();
@ -95,13 +101,15 @@ template <typename T>
static DataTypePtr createExact(const ASTPtr & arguments)
{
if (!arguments || arguments->children.size() != 1)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Decimal32 | Decimal64 | Decimal128 | Decimal256 data type family must have exactly one arguments: scale");
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Decimal32 | Decimal64 | Decimal128 | Decimal256 data type family must have exactly one arguments: scale");
const auto * scale_arg = arguments->children[0]->as<ASTLiteral>();
if (!scale_arg || !(scale_arg->value.getType() == Field::Types::Int64 || scale_arg->value.getType() == Field::Types::UInt64))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Decimal32 | Decimal64 | Decimal128 | Decimal256 data type family must have a one number as its argument");
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Decimal32 | Decimal64 | Decimal128 | Decimal256 data type family must have a one number as its argument");
UInt64 precision = DecimalUtils::max_precision<T>;
UInt64 scale = scale_arg->value.get<UInt64>();

View File

@ -1,7 +1,7 @@
#pragma once
#include <Core/Names.h>
#include <Core/Defines.h>
#include <Core/Names.h>
#include <base/types.h>
#include <base/unit.h>
@ -48,9 +48,9 @@ struct FormatSettings
enum class DateTimeInputFormat
{
Basic, /// Default format for fast parsing: YYYY-MM-DD hh:mm:ss (ISO-8601 without fractional part and timezone) or NNNNNNNNNN unix timestamp.
BestEffort, /// Use sophisticated rules to parse whatever possible.
BestEffortUS /// Use sophisticated rules to parse American style: mm/dd/yyyy
Basic, /// Default format for fast parsing: YYYY-MM-DD hh:mm:ss (ISO-8601 without fractional part and timezone) or NNNNNNNNNN unix timestamp.
BestEffort, /// Use sophisticated rules to parse whatever possible.
BestEffortUS /// Use sophisticated rules to parse American style: mm/dd/yyyy
};
DateTimeInputFormat date_time_input_format = DateTimeInputFormat::Basic;
@ -282,6 +282,14 @@ struct FormatSettings
uint32_t client_capabilities = 0;
size_t max_packet_size = 0;
uint8_t * sequence_id = nullptr; /// Not null if it's MySQLWire output format used to handle MySQL protocol connections.
/**
* COM_QUERY uses Text ResultSet
* https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html
* COM_STMT_EXECUTE uses Binary Protocol ResultSet
* https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute_response.html
* By default, use Text ResultSet.
*/
bool use_binary_result_set = false;
} mysql_wire;
struct

View File

@ -37,7 +37,7 @@ String InterpreterShowColumnsQuery::getRewrittenQuery()
SELECT
name AS field,
type AS type,
startsWith(type, 'Nullable') AS null,
if(startsWith(type, 'Nullable'), 'YES', 'NO') AS null,
trim(concatWithSeparator(' ', if (is_in_primary_key, 'PRI', ''), if (is_in_sorting_key, 'SOR', ''))) AS key,
if (default_kind IN ('ALIAS', 'DEFAULT', 'MATERIALIZED'), default_expression, NULL) AS default,
'' AS extra )";

View File

@ -1,11 +1,12 @@
#include <Processors/Formats/Impl/MySQLOutputFormat.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsProtocolBinary.h>
#include <Core/MySQL/PacketsProtocolText.h>
#include <Formats/FormatFactory.h>
#include <Formats/FormatSettings.h>
#include <Interpreters/Context.h>
#include <Interpreters/ProcessList.h>
#include <Processors/Formats/Impl/MySQLOutputFormat.h>
#include "Common/logger_useful.h"
namespace DB
{
@ -13,17 +14,18 @@ namespace DB
using namespace MySQLProtocol;
using namespace MySQLProtocol::Generic;
using namespace MySQLProtocol::ProtocolText;
using namespace MySQLProtocol::ProtocolBinary;
MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header_, const FormatSettings & settings_)
: IOutputFormat(header_, out_)
, client_capabilities(settings_.mysql_wire.client_capabilities)
: IOutputFormat(header_, out_), client_capabilities(settings_.mysql_wire.client_capabilities)
{
/// MySQlWire is a special format that is usually used as output format for MySQL protocol connections.
/// In this case we have a correct `sequence_id` stored in `settings_.mysql_wire`.
/// But it's also possible to specify MySQLWire as output format for clickhouse-client or clickhouse-local.
/// There is no `sequence_id` stored in `settings_.mysql_wire` in this case, so we create a dummy one.
sequence_id = settings_.mysql_wire.sequence_id ? settings_.mysql_wire.sequence_id : &dummy_sequence_id;
/// Switch between Text (COM_QUERY) and Binary (COM_EXECUTE_STMT) ResultSet
use_binary_result_set = settings_.mysql_wire.use_binary_result_set;
const auto & header = getPort(PortKind::Main).getHeader();
data_types = header.getDataTypes();
@ -54,7 +56,7 @@ void MySQLOutputFormat::writePrefix()
packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
}
if (!(client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
if (!(client_capabilities & Capability::CLIENT_DEPRECATE_EOF) && !use_binary_result_set)
{
packet_endpoint->sendPacket(EOFPacket(0, 0));
}
@ -63,39 +65,67 @@ void MySQLOutputFormat::writePrefix()
void MySQLOutputFormat::consume(Chunk chunk)
{
for (size_t i = 0; i < chunk.getNumRows(); ++i)
if (!use_binary_result_set)
{
ProtocolText::ResultSetRow row_packet(serializations, chunk.getColumns(), static_cast<int>(i));
packet_endpoint->sendPacket(row_packet);
for (size_t i = 0; i < chunk.getNumRows(); ++i)
{
ProtocolText::ResultSetRow row_packet(serializations, chunk.getColumns(), static_cast<int>(i));
packet_endpoint->sendPacket(row_packet);
}
}
else
{
for (size_t i = 0; i < chunk.getNumRows(); ++i)
{
ProtocolBinary::ResultSetRow row_packet(serializations, data_types, chunk.getColumns(), static_cast<int>(i));
packet_endpoint->sendPacket(row_packet);
}
}
}
void MySQLOutputFormat::finalizeImpl()
{
size_t affected_rows = 0;
std::string human_readable_info;
if (QueryStatusPtr process_list_elem = getContext()->getProcessListElement())
if (!use_binary_result_set)
{
CurrentThread::finalizePerformanceCounters();
QueryStatusInfo info = process_list_elem->getInfo();
affected_rows = info.written_rows;
double elapsed_seconds = static_cast<double>(info.elapsed_microseconds) / 1000000.0;
human_readable_info = fmt::format(
"Read {} rows, {} in {} sec., {} rows/sec., {}/sec.",
info.read_rows,
ReadableSize(info.read_bytes),
elapsed_seconds,
static_cast<size_t>(info.read_rows / elapsed_seconds),
ReadableSize(info.read_bytes / elapsed_seconds));
}
size_t affected_rows = 0;
std::string human_readable_info;
if (QueryStatusPtr process_list_elem = getContext()->getProcessListElement())
{
CurrentThread::finalizePerformanceCounters();
QueryStatusInfo info = process_list_elem->getInfo();
affected_rows = info.written_rows;
double elapsed_seconds = static_cast<double>(info.elapsed_microseconds) / 1000000.0;
human_readable_info = fmt::format(
"Read {} rows, {} in {} sec., {} rows/sec., {}/sec.",
info.read_rows,
ReadableSize(info.read_bytes),
elapsed_seconds,
static_cast<size_t>(info.read_rows / elapsed_seconds),
ReadableSize(info.read_bytes / elapsed_seconds));
}
const auto & header = getPort(PortKind::Main).getHeader();
if (header.columns() == 0)
packet_endpoint->sendPacket(OKPacket(0x0, client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else if (client_capabilities & CLIENT_DEPRECATE_EOF)
packet_endpoint->sendPacket(OKPacket(0xfe, client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
const auto & header = getPort(PortKind::Main).getHeader();
if (header.columns() == 0)
packet_endpoint->sendPacket(OKPacket(0x0, client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else if (client_capabilities & CLIENT_DEPRECATE_EOF)
packet_endpoint->sendPacket(OKPacket(0xfe, client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else
packet_endpoint->sendPacket(EOFPacket(0, 0), true);
}
else
packet_endpoint->sendPacket(EOFPacket(0, 0), true);
{
size_t affected_rows = 0;
if (QueryStatusPtr process_list_elem = getContext()->getProcessListElement())
{
CurrentThread::finalizePerformanceCounters();
QueryStatusInfo info = process_list_elem->getInfo();
affected_rows = info.written_rows;
}
if (client_capabilities & CLIENT_DEPRECATE_EOF)
packet_endpoint->sendPacket(OKPacket(0xfe, client_capabilities, affected_rows, 0, 0, "", ""), true);
else
packet_endpoint->sendPacket(EOFPacket(0, 0), true);
}
}
void MySQLOutputFormat::flush()
@ -107,9 +137,8 @@ void registerOutputFormatMySQLWire(FormatFactory & factory)
{
factory.registerOutputFormat(
"MySQLWire",
[](WriteBuffer & buf,
const Block & sample,
const FormatSettings & settings) { return std::make_shared<MySQLOutputFormat>(buf, sample, settings); });
[](WriteBuffer & buf, const Block & sample, const FormatSettings & settings)
{ return std::make_shared<MySQLOutputFormat>(buf, sample, settings); });
}
}

View File

@ -1,7 +1,7 @@
#pragma once
#include <Processors/Formats/IRowOutputFormat.h>
#include <Core/Block.h>
#include <Processors/Formats/IRowOutputFormat.h>
#include <Core/MySQL/PacketEndpoint.h>
#include <Processors/Formats/IOutputFormat.h>
@ -39,6 +39,7 @@ private:
MySQLProtocol::PacketEndpointPtr packet_endpoint;
DataTypes data_types;
Serializations serializations;
bool use_binary_result_set = false;
};
}

View File

@ -1,29 +1,29 @@
#include "MySQLHandler.h"
#include <limits>
#include <Common/NetException.h>
#include <Common/OpenSSLHelpers.h>
#include <regex>
#include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsPreparedStatements.h>
#include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsProtocolText.h>
#include <Core/NamesAndTypes.h>
#include <Interpreters/Session.h>
#include <Interpreters/executeQuery.h>
#include <IO/copyData.h>
#include <IO/LimitReadBuffer.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <IO/WriteBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <IO/copyData.h>
#include <Interpreters/Session.h>
#include <Interpreters/executeQuery.h>
#include <Server/TCPServer.h>
#include <Storages/IStorage.h>
#include <regex>
#include <Common/setThreadName.h>
#include <Core/MySQL/Authentication.h>
#include <Common/logger_useful.h>
#include <base/scope_guard.h>
#include <Common/NetException.h>
#include <Common/OpenSSLHelpers.h>
#include <Common/logger_useful.h>
#include <Common/setThreadName.h>
#include "config_version.h"
@ -67,10 +67,7 @@ static String killConnectionIdReplacementQuery(const String & query);
static String selectLimitReplacementQuery(const String & query);
MySQLHandler::MySQLHandler(
IServer & server_,
TCPServer & tcp_server_,
const Poco::Net::StreamSocket & socket_,
bool ssl_enabled, uint32_t connection_id_)
IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool ssl_enabled, uint32_t connection_id_)
: Poco::Net::TCPServerConnection(socket_)
, server(server_)
, tcp_server(tcp_server_)
@ -78,7 +75,8 @@ MySQLHandler::MySQLHandler(
, connection_id(connection_id_)
, auth_plugin(new MySQLProtocol::Authentication::Native41())
{
server_capabilities = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF;
server_capabilities = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
| CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF;
if (ssl_enabled)
server_capabilities |= CLIENT_SSL;
@ -104,8 +102,13 @@ void MySQLHandler::run()
try
{
Handshake handshake(server_capabilities, connection_id, VERSION_STRING + String("-") + VERSION_NAME,
auth_plugin->getName(), auth_plugin->getAuthPluginData(), CharacterSet::utf8_general_ci);
Handshake handshake(
server_capabilities,
connection_id,
VERSION_STRING + String("-") + VERSION_NAME,
auth_plugin->getName(),
auth_plugin->getAuthPluginData(),
CharacterSet::utf8_general_ci);
packet_endpoint->sendPacket<Handshake>(handshake, true);
LOG_TRACE(log, "Sent handshake");
@ -115,8 +118,10 @@ void MySQLHandler::run()
client_capabilities = handshake_response.capability_flags;
max_packet_size = handshake_response.max_packet_size ? handshake_response.max_packet_size : MAX_PACKET_LENGTH;
LOG_TRACE(log,
"Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: {}",
LOG_TRACE(
log,
"Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: "
"{}",
handshake_response.capability_flags,
handshake_response.max_packet_size,
static_cast<int>(handshake_response.character_set),
@ -160,8 +165,8 @@ void MySQLHandler::run()
// For commands which are executed without MemoryTracker.
LimitReadBuffer limited_payload(payload, 10000, /* trow_exception */ true, /* exact_limit */ {}, "too long MySQL packet.");
LOG_DEBUG(log, "Received command: {}. Connection id: {}.",
static_cast<int>(static_cast<unsigned char>(command)), connection_id);
LOG_DEBUG(
log, "Received command: {}. Connection id: {}.", static_cast<int>(static_cast<unsigned char>(command)), connection_id);
if (!tcp_server.isOpen())
return;
@ -175,7 +180,7 @@ void MySQLHandler::run()
comInitDB(limited_payload);
break;
case COM_QUERY:
comQuery(payload);
comQuery(payload, false);
break;
case COM_FIELD_LIST:
comFieldList(limited_payload);
@ -227,13 +232,15 @@ void MySQLHandler::finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResp
size_t pos = 0;
/// Reads at least count and at most packet_size bytes.
auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void {
auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void
{
while (pos < count)
{
int ret = socket().receiveBytes(buf + pos, static_cast<uint32_t>(packet_size - pos));
if (ret == 0)
{
throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Cannot read all data. Bytes read: {}. Bytes expected: 3", std::to_string(pos));
throw Exception(
ErrorCodes::CANNOT_READ_ALL_DATA, "Cannot read all data. Bytes read: {}. Bytes expected: 3", std::to_string(pos));
}
pos += ret;
}
@ -272,7 +279,8 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
authPluginSSL();
}
std::optional<String> auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt;
std::optional<String> auth_response
= auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt;
auth_plugin->authenticate(user_name, *session, auth_response, packet_endpoint, secure_connection, socket().peerAddress());
}
catch (const Exception & exc)
@ -304,8 +312,17 @@ void MySQLHandler::comFieldList(ReadBuffer & payload)
for (const NameAndTypePair & column : metadata_snapshot->getColumns().getAll())
{
ColumnDefinition column_definition(
database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0, true
);
database,
packet.table,
packet.table,
column.name,
column.name,
CharacterSet::binary,
100,
ColumnType::MYSQL_TYPE_STRING,
0,
0,
true);
packet_endpoint->sendPacket(column_definition);
}
packet_endpoint->sendPacket(OKPacket(0xfe, client_capabilities, 0, 0, 0), true);
@ -318,7 +335,7 @@ void MySQLHandler::comPing()
static bool isFederatedServerSetupSetCommand(const String & query);
void MySQLHandler::comQuery(ReadBuffer & payload)
void MySQLHandler::comQuery(ReadBuffer & payload, bool use_binary_protocol_result_set)
{
String query = String(payload.position(), payload.buffer().end());
@ -350,20 +367,22 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
query_context->setCurrentQueryId(fmt::format("mysql:{}:{}", connection_id, toString(UUIDHelpers::generateV4())));
CurrentThread::QueryScope query_scope{query_context};
std::atomic<size_t> affected_rows {0};
std::atomic<size_t> affected_rows{0};
auto prev = query_context->getProgressCallback();
query_context->setProgressCallback([&, my_prev = prev](const Progress & progress)
{
if (my_prev)
my_prev(progress);
query_context->setProgressCallback(
[&, my_prev = prev](const Progress & progress)
{
if (my_prev)
my_prev(progress);
affected_rows += progress.written_rows;
});
affected_rows += progress.written_rows;
});
FormatSettings format_settings;
format_settings.mysql_wire.client_capabilities = client_capabilities;
format_settings.mysql_wire.max_packet_size = max_packet_size;
format_settings.mysql_wire.sequence_id = &sequence_id;
format_settings.mysql_wire.use_binary_result_set = use_binary_protocol_result_set;
auto set_result_details = [&with_output](const QueryResultDetails & details)
{
@ -385,11 +404,18 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
void MySQLHandler::comStmtPrepare(DB::ReadBuffer & payload)
{
if (prepared_statements_map.size() > 10000) /// Shouldn't happen in reality as COM_STMT_CLOSE cleans up the elements
{
LOG_ERROR(log, "Too many prepared statements");
packet_endpoint->sendPacket(ERRPacket(), true);
return;
}
String query;
readStringUntilEOF(query, payload);
uint32_t statement_id = current_prepared_statement_id;
if (current_prepared_statement_id == std::numeric_limits<uint32_t>::max()) [[unlikely]]
if (current_prepared_statement_id == std::numeric_limits<uint32_t>::max())
{
current_prepared_statement_id = 0;
}
@ -400,7 +426,7 @@ void MySQLHandler::comStmtPrepare(DB::ReadBuffer & payload)
// Key collisions should not happen here, as we remove the elements from the map with COM_STMT_CLOSE,
// and we have quite a big range of available identifiers with 32-bit unsigned integer
if (prepared_statements_map.contains(statement_id)) [[unlikely]]
if (prepared_statements_map.contains(statement_id))
{
LOG_ERROR(
log,
@ -411,8 +437,8 @@ void MySQLHandler::comStmtPrepare(DB::ReadBuffer & payload)
packet_endpoint->sendPacket(ERRPacket(), true);
return;
}
prepared_statements_map.emplace(statement_id, query);
prepared_statements_map.emplace(statement_id, query);
packet_endpoint->sendPacket(PrepareStatementResponseOK(statement_id, 0, 0, 0), true);
}
@ -421,7 +447,7 @@ void MySQLHandler::comStmtExecute(ReadBuffer & payload)
uint32_t statement_id;
payload.readStrict(reinterpret_cast<char *>(&statement_id), 4);
if (!prepared_statements_map.contains(statement_id)) [[unlikely]]
if (!prepared_statements_map.contains(statement_id))
{
LOG_ERROR(log, "Could not find prepared statement with id {}", statement_id);
packet_endpoint->sendPacket(ERRPacket(), true);
@ -430,14 +456,16 @@ void MySQLHandler::comStmtExecute(ReadBuffer & payload)
// Temporary workaround as we work only with queries that do not bind any parameters atm
ReadBufferFromString com_query_payload(prepared_statements_map.at(statement_id));
MySQLHandler::comQuery(com_query_payload);
MySQLHandler::comQuery(com_query_payload, true);
};
void MySQLHandler::comStmtClose([[maybe_unused]] ReadBuffer & payload) {
void MySQLHandler::comStmtClose(ReadBuffer & payload)
{
uint32_t statement_id;
payload.readStrict(reinterpret_cast<char *>(&statement_id), 4);
if (prepared_statements_map.contains(statement_id)) {
if (prepared_statements_map.contains(statement_id))
{
prepared_statements_map.erase(statement_id);
}
@ -447,13 +475,17 @@ void MySQLHandler::comStmtClose([[maybe_unused]] ReadBuffer & payload) {
void MySQLHandler::authPluginSSL()
{
throw Exception(ErrorCodes::SUPPORT_IS_DISABLED,
"ClickHouse was built without SSL support. Try specifying password using double SHA1 in users.xml.");
throw Exception(
ErrorCodes::SUPPORT_IS_DISABLED,
"ClickHouse was built without SSL support. Try specifying password using double SHA1 in users.xml.");
}
void MySQLHandler::finishHandshakeSSL(
[[maybe_unused]] size_t packet_size, [[maybe_unused]] char * buf, [[maybe_unused]] size_t pos,
[[maybe_unused]] std::function<void(size_t)> read_bytes, [[maybe_unused]] MySQLProtocol::ConnectionPhase::HandshakeResponse & packet)
[[maybe_unused]] size_t packet_size,
[[maybe_unused]] char * buf,
[[maybe_unused]] size_t pos,
[[maybe_unused]] std::function<void(size_t)> read_bytes,
[[maybe_unused]] MySQLProtocol::ConnectionPhase::HandshakeResponse & packet)
{
throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "Client requested SSL, while it is disabled.");
}
@ -467,10 +499,9 @@ MySQLHandlerSSL::MySQLHandlerSSL(
uint32_t connection_id_,
RSA & public_key_,
RSA & private_key_)
: MySQLHandler(server_, tcp_server_, socket_, ssl_enabled, connection_id_)
, public_key(public_key_)
, private_key(private_key_)
{}
: MySQLHandler(server_, tcp_server_, socket_, ssl_enabled, connection_id_), public_key(public_key_), private_key(private_key_)
{
}
void MySQLHandlerSSL::authPluginSSL()
{
@ -478,7 +509,10 @@ void MySQLHandlerSSL::authPluginSSL()
}
void MySQLHandlerSSL::finishHandshakeSSL(
size_t packet_size, char *buf, size_t pos, std::function<void(size_t)> read_bytes,
size_t packet_size,
char * buf,
size_t pos,
std::function<void(size_t)> read_bytes,
MySQLProtocol::ConnectionPhase::HandshakeResponse & packet)
{
read_bytes(packet_size); /// Reading rest SSLRequest.
@ -508,8 +542,8 @@ static bool isFederatedServerSetupSetCommand(const String & query)
"|(^(SET AUTOCOMMIT(.*)))"
"|(^(SET sql_mode(.*)))"
"|(^(SET @@(.*)))"
"|(^(SET SESSION TRANSACTION ISOLATION LEVEL(.*)))"
, std::regex::icase};
"|(^(SET SESSION TRANSACTION ISOLATION LEVEL(.*)))",
std::regex::icase};
return 1 == std::regex_match(query, expr);
}

View File

@ -1,12 +1,12 @@
#pragma once
#include <Poco/Net/TCPServerConnection.h>
#include <base/getFQDNOrHostName.h>
#include <Common/CurrentMetrics.h>
#include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsProtocolText.h>
#include <base/getFQDNOrHostName.h>
#include <Poco/Net/TCPServerConnection.h>
#include <Common/CurrentMetrics.h>
#include "IServer.h"
#include "config.h"
@ -19,7 +19,7 @@
namespace CurrentMetrics
{
extern const Metric MySQLConnection;
extern const Metric MySQLConnection;
}
namespace DB
@ -32,11 +32,7 @@ class MySQLHandler : public Poco::Net::TCPServerConnection
{
public:
MySQLHandler(
IServer & server_,
TCPServer & tcp_server_,
const Poco::Net::StreamSocket & socket_,
bool ssl_enabled,
uint32_t connection_id_);
IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool ssl_enabled, uint32_t connection_id_);
void run() final;
@ -46,7 +42,7 @@ protected:
/// Enables SSL, if client requested.
void finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResponse &);
void comQuery(ReadBuffer & payload);
void comQuery(ReadBuffer & payload, bool use_binary_protocol_result_set);
void comFieldList(ReadBuffer & payload);
@ -63,7 +59,12 @@ protected:
void comStmtClose(ReadBuffer & payload);
virtual void authPluginSSL();
virtual void finishHandshakeSSL(size_t packet_size, char * buf, size_t pos, std::function<void(size_t)> read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet);
virtual void finishHandshakeSSL(
size_t packet_size,
char * buf,
size_t pos,
std::function<void(size_t)> read_bytes,
MySQLProtocol::ConnectionPhase::HandshakeResponse & packet);
IServer & server;
TCPServer & tcp_server;
@ -109,8 +110,11 @@ private:
void authPluginSSL() override;
void finishHandshakeSSL(
size_t packet_size, char * buf, size_t pos,
std::function<void(size_t)> read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) override;
size_t packet_size,
char * buf,
size_t pos,
std::function<void(size_t)> read_bytes,
MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) override;
RSA & public_key;
RSA & private_key;

View File

@ -0,0 +1,38 @@
ATTACH VIEW key_column_usage
(
`referenced_table_schema` Nullable(String),
`referenced_table_name` Nullable(String),
`referenced_column_name` Nullable(String),
`table_schema` String,
`table_name` String,
`column_name` Nullable(String),
`ordinal_position` UInt32,
`constraint_name` Nullable(String),
`REFERENCED_TABLE_SCHEMA` Nullable(String),
`REFERENCED_TABLE_NAME` Nullable(String),
`REFERENCED_COLUMN_NAME` Nullable(String),
`TABLE_SCHEMA` String,
`TABLE_NAME` String,
`COLUMN_NAME` Nullable(String),
`ORDINAL_POSITION` UInt32,
`CONSTRAINT_NAME` Nullable(String)
) AS
SELECT NULL AS `referenced_table_schema`,
NULL AS `referenced_table_name`,
NULL AS `referenced_column_name`,
database AS `table_schema`,
table AS `table_name`,
name AS `column_name`,
position AS `ordinal_position`,
'PRIMARY' AS `constraint_name`,
`referenced_table_schema` AS `REFERENCED_TABLE_SCHEMA`,
`referenced_table_name` AS `REFERENCED_TABLE_NAME`,
`referenced_column_name` AS `REFERENCED_COLUMN_NAME`,
`table_schema` AS `TABLE_SCHEMA`,
`table_name` AS `TABLE_NAME`,
`column_name` AS `COLUMN_NAME`,
`ordinal_position` AS `ORDINAL_POSITION`,
`constraint_name` AS `CONSTRAINT_NAME`
FROM system.columns
WHERE is_in_primary_key;

View File

@ -0,0 +1,25 @@
ATTACH VIEW referential_constraints
(
`constraint_name` Nullable(String),
`constraint_schema` String,
`table_name` String,
`update_rule` String,
`delete_rule` String,
`CONSTRAINT_NAME` Nullable(String),
`CONSTRAINT_SCHEMA` String,
`TABLE_NAME` String,
`UPDATE_RULE` String,
`DELETE_RULE` String
) AS
SELECT NULL AS `constraint_name`,
'' AS `constraint_schema`,
'' AS `table_name`,
'' AS `update_rule`,
'' AS `delete_rule`,
NULL AS `CONSTRAINT_NAME`,
'' AS `CONSTRAINT_SCHEMA`,
'' AS `TABLE_NAME`,
'' AS `UPDATE_RULE`,
'' AS `DELETE_RULE`
WHERE false;

View File

@ -1,26 +1,33 @@
ATTACH VIEW schemata
(
`catalog_name` String,
`schema_name` String,
`schema_owner` String,
`default_character_set_catalog` Nullable(String),
`default_character_set_schema` Nullable(String),
`default_character_set_name` Nullable(String),
`sql_path` Nullable(String),
`CATALOG_NAME` String ALIAS catalog_name,
`SCHEMA_NAME` String ALIAS schema_name,
`SCHEMA_OWNER` String ALIAS schema_owner,
`DEFAULT_CHARACTER_SET_CATALOG` Nullable(String) ALIAS default_character_set_catalog,
`DEFAULT_CHARACTER_SET_SCHEMA` Nullable(String) ALIAS default_character_set_schema,
`DEFAULT_CHARACTER_SET_NAME` Nullable(String) ALIAS default_character_set_name,
`SQL_PATH` Nullable(String) ALIAS sql_path
) AS
SELECT
name AS catalog_name,
name AS schema_name,
'default' AS schema_owner,
NULL AS default_character_set_catalog,
NULL AS default_character_set_schema,
NULL AS default_character_set_name,
NULL AS sql_path
(
`catalog_name` String,
`schema_name` String,
`schema_owner` String,
`default_character_set_catalog` Nullable(String),
`default_character_set_schema` Nullable(String),
`default_character_set_name` Nullable(String),
`sql_path` Nullable(String),
`CATALOG_NAME` String,
`SCHEMA_NAME` String,
`SCHEMA_OWNER` String,
`DEFAULT_CHARACTER_SET_CATALOG` Nullable(String),
`DEFAULT_CHARACTER_SET_SCHEMA` Nullable(String),
`DEFAULT_CHARACTER_SET_NAME` Nullable(String),
`SQL_PATH` Nullable(String)
) AS
SELECT name AS `catalog_name`,
name AS `schema_name`,
'default' AS `schema_owner`,
NULL AS `default_character_set_catalog`,
NULL AS `default_character_set_schema`,
NULL AS `default_character_set_name`,
NULL AS `sql_path`,
catalog_name AS `CATALOG_NAME`,
schema_name AS `SCHEMA_NAME`,
schema_owner AS `SCHEMA_OWNER`,
NULL AS `DEFAULT_CHARACTER_SET_CATALOG`,
NULL AS `DEFAULT_CHARACTER_SET_SCHEMA`,
NULL AS `DEFAULT_CHARACTER_SET_NAME`,
NULL AS `SQL_PATH`
FROM system.databases

View File

@ -1,17 +1,35 @@
ATTACH VIEW tables
(
`table_catalog` String,
`table_schema` String,
`table_name` String,
`table_type` Enum8('BASE TABLE' = 1, 'VIEW' = 2, 'FOREIGN TABLE' = 3, 'LOCAL TEMPORARY' = 4, 'SYSTEM VIEW' = 5),
`TABLE_CATALOG` String ALIAS table_catalog,
`TABLE_SCHEMA` String ALIAS table_schema,
`TABLE_NAME` String ALIAS table_name,
`TABLE_TYPE` Enum8('BASE TABLE' = 1, 'VIEW' = 2, 'FOREIGN TABLE' = 3, 'LOCAL TEMPORARY' = 4, 'SYSTEM VIEW' = 5) ALIAS table_type
) AS
SELECT
database AS table_catalog,
database AS table_schema,
name AS table_name,
multiIf(is_temporary, 4, engine like '%View', 2, engine LIKE 'System%', 5, has_own_data = 0, 3, 1) AS table_type
FROM system.tables
(
`table_catalog` String,
`table_schema` String,
`table_name` String,
`table_type` String,
`table_comment` String,
`table_collation` String,
`TABLE_CATALOG` String,
`TABLE_SCHEMA` String,
`TABLE_NAME` String,
`TABLE_TYPE` String,
`TABLE_COMMENT` String,
`TABLE_COLLATION` String
) AS
SELECT database AS `table_catalog`,
database AS `table_schema`,
name AS `table_name`,
comment AS `table_comment`,
multiIf(
is_temporary, 'LOCAL TEMPORARY',
engine LIKE '%View', 'VIEW',
engine LIKE 'System%', 'SYSTEM VIEW',
has_own_data = 0, 'FOREIGN TABLE',
'BASE TABLE'
) AS `table_type`,
'utf8mb4_0900_ai_ci' AS `table_collation`,
table_catalog AS `TABLE_CATALOG`,
table_schema AS `TABLE_SCHEMA`,
table_name AS `TABLE_NAME`,
table_comment AS `TABLE_COMMENT`,
table_type AS `TABLE_TYPE`,
table_collation AS `TABLE_COLLATION`
FROM system.tables

View File

@ -12,7 +12,8 @@ INCBIN(resource_schemata_sql, SOURCE_DIR "/src/Storages/System/InformationSchema
INCBIN(resource_tables_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/tables.sql");
INCBIN(resource_views_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/views.sql");
INCBIN(resource_columns_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/columns.sql");
INCBIN(resource_key_column_usage_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/key_column_usage.sql");
INCBIN(resource_referential_constraints_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/referential_constraints.sql");
namespace DB
{
@ -66,6 +67,8 @@ void attachInformationSchema(ContextMutablePtr context, IDatabase & information_
createInformationSchemaView(context, information_schema_database, "tables", std::string_view(reinterpret_cast<const char *>(gresource_tables_sqlData), gresource_tables_sqlSize));
createInformationSchemaView(context, information_schema_database, "views", std::string_view(reinterpret_cast<const char *>(gresource_views_sqlData), gresource_views_sqlSize));
createInformationSchemaView(context, information_schema_database, "columns", std::string_view(reinterpret_cast<const char *>(gresource_columns_sqlData), gresource_columns_sqlSize));
createInformationSchemaView(context, information_schema_database, "key_column_usage", std::string_view(reinterpret_cast<const char *>(gresource_key_column_usage_sqlData), gresource_key_column_usage_sqlSize));
createInformationSchemaView(context, information_schema_database, "referential_constraints", std::string_view(reinterpret_cast<const char *>(gresource_referential_constraints_sqlData), gresource_referential_constraints_sqlSize));
}
}