diff --git a/src/Core/MySQL/PacketsGeneric.cpp b/src/Core/MySQL/PacketsGeneric.cpp index af80797d5c1..88183890399 100644 --- a/src/Core/MySQL/PacketsGeneric.cpp +++ b/src/Core/MySQL/PacketsGeneric.cpp @@ -8,254 +8,263 @@ namespace DB namespace MySQLProtocol { -namespace Generic -{ - -static const size_t MYSQL_ERRMSG_SIZE = 512; - -void SSLRequest::readPayloadImpl(ReadBuffer & buf) -{ - buf.readStrict(reinterpret_cast(&capability_flags), 4); - buf.readStrict(reinterpret_cast(&max_packet_size), 4); - buf.readStrict(reinterpret_cast(&character_set), 1); -} - -OKPacket::OKPacket(uint32_t capabilities_) - : header(0x00), capabilities(capabilities_), affected_rows(0x00), last_insert_id(0x00), status_flags(0x00) -{ -} - -OKPacket::OKPacket( - uint8_t header_, uint32_t capabilities_, uint64_t affected_rows_, uint32_t status_flags_, int16_t warnings_, - String session_state_changes_, String info_) - : header(header_), capabilities(capabilities_), affected_rows(affected_rows_), last_insert_id(0), warnings(warnings_), - status_flags(status_flags_), session_state_changes(std::move(session_state_changes_)), info(std::move(info_)) -{ -} - -size_t OKPacket::getPayloadSize() const -{ - size_t result = 2 + getLengthEncodedNumberSize(affected_rows); - - if (capabilities & CLIENT_PROTOCOL_41) + namespace Generic { - result += 4; - } - else if (capabilities & CLIENT_TRANSACTIONS) - { - result += 2; - } - if (capabilities & CLIENT_SESSION_TRACK) - { - result += getLengthEncodedStringSize(info); - if (status_flags & SERVER_SESSION_STATE_CHANGED) - result += getLengthEncodedStringSize(session_state_changes); - } - else - { - result += info.size(); - } + static const size_t MYSQL_ERRMSG_SIZE = 512; - return result; -} - -void OKPacket::readPayloadImpl(ReadBuffer & payload) - -{ - payload.readStrict(reinterpret_cast(&header), 1); - affected_rows = readLengthEncodedNumber(payload); - last_insert_id = readLengthEncodedNumber(payload); - - if (capabilities & CLIENT_PROTOCOL_41) - { - payload.readStrict(reinterpret_cast(&status_flags), 2); - payload.readStrict(reinterpret_cast(&warnings), 2); - } - else if (capabilities & CLIENT_TRANSACTIONS) - { - payload.readStrict(reinterpret_cast(&status_flags), 2); - } - - if (capabilities & CLIENT_SESSION_TRACK) - { - readLengthEncodedString(info, payload); - if (status_flags & SERVER_SESSION_STATE_CHANGED) + void SSLRequest::readPayloadImpl(ReadBuffer & buf) { - readLengthEncodedString(session_state_changes, payload); + buf.readStrict(reinterpret_cast(&capability_flags), 4); + buf.readStrict(reinterpret_cast(&max_packet_size), 4); + buf.readStrict(reinterpret_cast(&character_set), 1); } - } - else - { - readString(info, payload); - } -} -void OKPacket::writePayloadImpl(WriteBuffer & buffer) const + OKPacket::OKPacket(uint32_t capabilities_) + : header(0x00), capabilities(capabilities_), affected_rows(0x00), last_insert_id(0x00), status_flags(0x00) + { + } -{ - buffer.write(header); - writeLengthEncodedNumber(affected_rows, buffer); - writeLengthEncodedNumber(last_insert_id, buffer); /// last insert-id + OKPacket::OKPacket( + uint8_t header_, + uint32_t capabilities_, + uint64_t affected_rows_, + uint32_t status_flags_, + int16_t warnings_, + String session_state_changes_, + String info_) + : header(header_) + , capabilities(capabilities_) + , affected_rows(affected_rows_) + , last_insert_id(0) + , warnings(warnings_) + , status_flags(status_flags_) + , session_state_changes(std::move(session_state_changes_)) + , info(std::move(info_)) + { + } - if (capabilities & CLIENT_PROTOCOL_41) - { - buffer.write(reinterpret_cast(&status_flags), 2); - buffer.write(reinterpret_cast(&warnings), 2); - } - else if (capabilities & CLIENT_TRANSACTIONS) - { - buffer.write(reinterpret_cast(&status_flags), 2); - } + size_t OKPacket::getPayloadSize() const + { + size_t result = 2 + getLengthEncodedNumberSize(affected_rows); - if (capabilities & CLIENT_SESSION_TRACK) - { - writeLengthEncodedString(info, buffer); - if (status_flags & SERVER_SESSION_STATE_CHANGED) - writeLengthEncodedString(session_state_changes, buffer); - } - else - { - writeString(info, buffer); - } -} - -EOFPacket::EOFPacket() : warnings(0x00), status_flags(0x00) -{ -} - -EOFPacket::EOFPacket(int warnings_, int status_flags_) - : warnings(warnings_), status_flags(status_flags_) -{ -} - -size_t EOFPacket::getPayloadSize() const -{ - return 5; -} - -void EOFPacket::readPayloadImpl(ReadBuffer & payload) -{ - payload.readStrict(reinterpret_cast(&header), 1); - assert(header == 0xfe); - payload.readStrict(reinterpret_cast(&warnings), 2); - payload.readStrict(reinterpret_cast(&status_flags), 2); -} - -void EOFPacket::writePayloadImpl(WriteBuffer & buffer) const -{ - buffer.write(header); // EOF header - buffer.write(reinterpret_cast(&warnings), 2); - buffer.write(reinterpret_cast(&status_flags), 2); -} - -void AuthSwitchPacket::readPayloadImpl(ReadBuffer & payload) -{ - payload.readStrict(reinterpret_cast(&header), 1); - assert(header == 0xfe); - readStringUntilEOF(plugin_name, payload); -} - -ERRPacket::ERRPacket() : error_code(0x00) -{ -} - -ERRPacket::ERRPacket(int error_code_, String sql_state_, String error_message_) - : error_code(error_code_), sql_state(std::move(sql_state_)), error_message(std::move(error_message_)) -{ -} - -size_t ERRPacket::getPayloadSize() const -{ - return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE); -} - -void ERRPacket::readPayloadImpl(ReadBuffer & payload) -{ - payload.readStrict(reinterpret_cast(&header), 1); - assert(header == 0xff); - - payload.readStrict(reinterpret_cast(&error_code), 2); - - /// SQL State [optional: # + 5bytes string] - UInt8 sharp = static_cast(*payload.position()); - if (sharp == 0x23) - { - payload.ignore(1); - sql_state.resize(5); - payload.readStrict(reinterpret_cast(sql_state.data()), 5); - } - readString(error_message, payload); -} - -void ERRPacket::writePayloadImpl(WriteBuffer & buffer) const -{ - buffer.write(header); - buffer.write(reinterpret_cast(&error_code), 2); - buffer.write('#'); - buffer.write(sql_state.data(), sql_state.length()); - buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE)); -} - -ResponsePacket::ResponsePacket(UInt32 server_capability_flags_) - : ok(OKPacket(server_capability_flags_)) -{ -} - -ResponsePacket::ResponsePacket(UInt32 server_capability_flags_, bool is_handshake_) - : ok(OKPacket(server_capability_flags_)), is_handshake(is_handshake_) -{ -} - -void ResponsePacket::readPayloadImpl(ReadBuffer & payload) -{ - UInt16 header = static_cast(*payload.position()); - switch (header) - { - case PACKET_OK: - packetType = PACKET_OK; - ok.readPayloadWithUnpacked(payload); - break; - case PACKET_ERR: - packetType = PACKET_ERR; - err.readPayloadWithUnpacked(payload); - break; - case PACKET_EOF: - if (is_handshake) + if (capabilities & CLIENT_PROTOCOL_41) { - packetType = PACKET_AUTH_SWITCH; - auth_switch.readPayloadWithUnpacked(payload); + result += 4; + } + else if (capabilities & CLIENT_TRANSACTIONS) + { + result += 2; + } + + if (capabilities & CLIENT_SESSION_TRACK) + { + result += getLengthEncodedStringSize(info); + if (status_flags & SERVER_SESSION_STATE_CHANGED) + result += getLengthEncodedStringSize(session_state_changes); } else { - packetType = PACKET_EOF; - eof.readPayloadWithUnpacked(payload); + result += info.size(); } - break; - case PACKET_LOCALINFILE: - packetType = PACKET_LOCALINFILE; - break; - default: - packetType = PACKET_OK; - column_length = readLengthEncodedNumber(payload); + + return result; + } + + void OKPacket::readPayloadImpl(ReadBuffer & payload) + + { + payload.readStrict(reinterpret_cast(&header), 1); + affected_rows = readLengthEncodedNumber(payload); + last_insert_id = readLengthEncodedNumber(payload); + + if (capabilities & CLIENT_PROTOCOL_41) + { + payload.readStrict(reinterpret_cast(&status_flags), 2); + payload.readStrict(reinterpret_cast(&warnings), 2); + } + else if (capabilities & CLIENT_TRANSACTIONS) + { + payload.readStrict(reinterpret_cast(&status_flags), 2); + } + + if (capabilities & CLIENT_SESSION_TRACK) + { + readLengthEncodedString(info, payload); + if (status_flags & SERVER_SESSION_STATE_CHANGED) + { + readLengthEncodedString(session_state_changes, payload); + } + } + else + { + readString(info, payload); + } + } + + void OKPacket::writePayloadImpl(WriteBuffer & buffer) const + + { + buffer.write(header); + writeLengthEncodedNumber(affected_rows, buffer); + writeLengthEncodedNumber(last_insert_id, buffer); /// last insert-id + + if (capabilities & CLIENT_PROTOCOL_41) + { + buffer.write(reinterpret_cast(&status_flags), 2); + buffer.write(reinterpret_cast(&warnings), 2); + } + else if (capabilities & CLIENT_TRANSACTIONS) + { + buffer.write(reinterpret_cast(&status_flags), 2); + } + + if (capabilities & CLIENT_SESSION_TRACK) + { + writeLengthEncodedString(info, buffer); + if (status_flags & SERVER_SESSION_STATE_CHANGED) + writeLengthEncodedString(session_state_changes, buffer); + } + else + { + writeString(info, buffer); + } + } + + EOFPacket::EOFPacket() : warnings(0x00), status_flags(0x00) + { + } + + EOFPacket::EOFPacket(int warnings_, int status_flags_) : warnings(warnings_), status_flags(status_flags_) + { + } + + size_t EOFPacket::getPayloadSize() const + { + return 5; + } + + void EOFPacket::readPayloadImpl(ReadBuffer & payload) + { + payload.readStrict(reinterpret_cast(&header), 1); + assert(header == 0xfe); + payload.readStrict(reinterpret_cast(&warnings), 2); + payload.readStrict(reinterpret_cast(&status_flags), 2); + } + + void EOFPacket::writePayloadImpl(WriteBuffer & buffer) const + { + buffer.write(header); // EOF header + buffer.write(reinterpret_cast(&warnings), 2); + buffer.write(reinterpret_cast(&status_flags), 2); + } + + void AuthSwitchPacket::readPayloadImpl(ReadBuffer & payload) + { + payload.readStrict(reinterpret_cast(&header), 1); + assert(header == 0xfe); + readStringUntilEOF(plugin_name, payload); + } + + ERRPacket::ERRPacket() : error_code(0x00) + { + } + + ERRPacket::ERRPacket(int error_code_, String sql_state_, String error_message_) + : error_code(error_code_), sql_state(std::move(sql_state_)), error_message(std::move(error_message_)) + { + } + + size_t ERRPacket::getPayloadSize() const + { + return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE); + } + + void ERRPacket::readPayloadImpl(ReadBuffer & payload) + { + payload.readStrict(reinterpret_cast(&header), 1); + assert(header == 0xff); + + payload.readStrict(reinterpret_cast(&error_code), 2); + + /// SQL State [optional: # + 5bytes string] + UInt8 sharp = static_cast(*payload.position()); + if (sharp == 0x23) + { + payload.ignore(1); + sql_state.resize(5); + payload.readStrict(reinterpret_cast(sql_state.data()), 5); + } + readString(error_message, payload); + } + + void ERRPacket::writePayloadImpl(WriteBuffer & buffer) const + { + buffer.write(header); + buffer.write(reinterpret_cast(&error_code), 2); + buffer.write('#'); + buffer.write(sql_state.data(), sql_state.length()); + buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE)); + } + + ResponsePacket::ResponsePacket(UInt32 server_capability_flags_) : ok(OKPacket(server_capability_flags_)) + { + } + + ResponsePacket::ResponsePacket(UInt32 server_capability_flags_, bool is_handshake_) + : ok(OKPacket(server_capability_flags_)), is_handshake(is_handshake_) + { + } + + void ResponsePacket::readPayloadImpl(ReadBuffer & payload) + { + UInt16 header = static_cast(*payload.position()); + switch (header) + { + case PACKET_OK: + packetType = PACKET_OK; + ok.readPayloadWithUnpacked(payload); + break; + case PACKET_ERR: + packetType = PACKET_ERR; + err.readPayloadWithUnpacked(payload); + break; + case PACKET_EOF: + if (is_handshake) + { + packetType = PACKET_AUTH_SWITCH; + auth_switch.readPayloadWithUnpacked(payload); + } + else + { + packetType = PACKET_EOF; + eof.readPayloadWithUnpacked(payload); + } + break; + case PACKET_LOCALINFILE: + packetType = PACKET_LOCALINFILE; + break; + default: + packetType = PACKET_OK; + column_length = readLengthEncodedNumber(payload); + } + } + + LengthEncodedNumber::LengthEncodedNumber(uint64_t value_) : value(value_) + { + } + + size_t LengthEncodedNumber::getPayloadSize() const + { + return getLengthEncodedNumberSize(value); + } + + void LengthEncodedNumber::writePayloadImpl(WriteBuffer & buffer) const + { + writeLengthEncodedNumber(value, buffer); + } + } -} - -LengthEncodedNumber::LengthEncodedNumber(uint64_t value_) : value(value_) -{ -} - -size_t LengthEncodedNumber::getPayloadSize() const -{ - return getLengthEncodedNumberSize(value); -} - -void LengthEncodedNumber::writePayloadImpl(WriteBuffer & buffer) const -{ - writeLengthEncodedNumber(value, buffer); -} - -} } diff --git a/src/Core/MySQL/PacketsProtocolBinary.cpp b/src/Core/MySQL/PacketsProtocolBinary.cpp new file mode 100644 index 00000000000..287dda269e6 --- /dev/null +++ b/src/Core/MySQL/PacketsProtocolBinary.cpp @@ -0,0 +1,230 @@ +#include +#include +#include +#include +#include +#include +#include "Columns/ColumnLowCardinality.h" +#include "Columns/ColumnVector.h" +#include "DataTypes/DataTypeLowCardinality.h" +#include "DataTypes/DataTypeNullable.h" +#include "Formats/FormatSettings.h" +#include "IO/WriteBufferFromString.h" +#include "base/types.h" + +namespace DB +{ + +namespace MySQLProtocol +{ + + namespace ProtocolBinary + { + ResultSetRow::ResultSetRow( + const Serializations & serializations_, const DataTypes & data_types_, const Columns & columns_, int row_num_) + : row_num(row_num_), columns(columns_), data_types(data_types_), serializations(serializations_) + { + /// See https://dev.mysql.com/doc/dev/mysql-server/8.1.0/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row + payload_size = 1 + null_bitmap_size; + // LOG_TRACE(&Poco::Logger::get("ResultSetRow"), "Null bitmap size: {}", null_bitmap_size); + FormatSettings format_settings; + for (size_t i = 0; i < columns.size(); ++i) + { + ColumnPtr col = columns[i]; + + if (col->isNullAt(row_num)) + { + null_bitmap[i / 8] |= 1 << i % 8; + } + + TypeIndex type_index = removeNullable(removeLowCardinality(data_types[i]))->getTypeId(); + switch (type_index) + { + case TypeIndex::Int8: + case TypeIndex::UInt8: + payload_size += 1; + break; + case TypeIndex::Int16: + case TypeIndex::UInt16: + payload_size += 2; + break; + case TypeIndex::Int32: + case TypeIndex::UInt32: + case TypeIndex::Float32: + payload_size += 4; + break; + case TypeIndex::Int64: + case TypeIndex::UInt64: + case TypeIndex::Float64: + payload_size += 8; + break; + case TypeIndex::Date: { + UInt64 value = col->get64(row_num); + if (value == 0) + { + payload_size += 1; // length only, no other fields + } + else + { + payload_size += 5; + } + break; + } + case TypeIndex::DateTime: { + UInt64 value = col->get64(row_num); + if (value == 0) + { + payload_size += 1; // length only, no other fields + } + else + { + Poco::DateTime dt = Poco::DateTime(Poco::Timestamp(value * 1000 * 1000)); + if (dt.second() == 0 && dt.minute() == 0 && dt.hour() == 0) + { + payload_size += 5; + } + else + { + payload_size += 8; + } + } + break; + } + default: + WriteBufferFromOwnString ostr; + serializations[i]->serializeText(*columns[i], row_num, ostr, format_settings); + payload_size += getLengthEncodedStringSize(ostr.str()); + serialized[i] = std::move(ostr.str()); + break; + } + } + } + + void ResultSetRow::writePayloadImpl(WriteBuffer & buffer) const + { + buffer.write(static_cast(0x00)); + buffer.write(null_bitmap.data(), null_bitmap_size); + for (size_t i = 0; i < columns.size(); ++i) + { + ColumnPtr col = columns[i]; + if (col->isNullAt(row_num)) + { + continue; // NULLs are stored in the null bitmap only + } + + TypeIndex type_index = removeNullable(removeLowCardinality(data_types[i]))->getTypeId(); + switch (type_index) + { + case TypeIndex::UInt8: { + UInt64 value = col->get64(row_num); + buffer.write(reinterpret_cast(&value), 1); + break; + } + case TypeIndex::UInt16: { + UInt64 value = col->get64(row_num); + buffer.write(reinterpret_cast(&value), 2); + break; + } + case TypeIndex::UInt32: { + UInt64 value = col->get64(row_num); + buffer.write(reinterpret_cast(&value), 4); + break; + } + case TypeIndex::UInt64: { + UInt64 value = col->get64(row_num); + buffer.write(reinterpret_cast(&value), 8); + break; + } + case TypeIndex::Int8: { + UInt64 value = col->get64(row_num); + buffer.write(reinterpret_cast(&value), 1); + break; + } + case TypeIndex::Int16: { + UInt64 value = col->get64(row_num); + buffer.write(reinterpret_cast(&value), 2); + break; + } + case TypeIndex::Int32: { + UInt64 value = col->get64(row_num); + buffer.write(reinterpret_cast(&value), 4); + break; + } + case TypeIndex::Int64: { + UInt64 value = col->get64(row_num); + buffer.write(reinterpret_cast(&value), 8); + break; + } + case TypeIndex::Float32: { + Float32 value = col->getFloat32(row_num); + buffer.write(reinterpret_cast(&value), 4); + break; + } + case TypeIndex::Float64: { + Float64 value = col->getFloat64(row_num); + buffer.write(reinterpret_cast(&value), 8); + break; + } + case TypeIndex::Date: { + UInt64 value = col->get64(row_num); + if (value != 0) + { + Poco::DateTime dt = Poco::DateTime(Poco::Timestamp(value * 1000 * 1000)); + buffer.write(static_cast(4)); // bytes_following + int year = dt.year(); + int month = dt.month(); + int day = dt.day(); + buffer.write(reinterpret_cast(&year), 2); + buffer.write(reinterpret_cast(&month), 1); + buffer.write(reinterpret_cast(&day), 1); + } + else + { + buffer.write(static_cast(0)); + } + break; + } + case TypeIndex::DateTime: { + UInt64 value = col->get64(row_num); + if (value != 0) + { + Poco::DateTime dt = Poco::DateTime(Poco::Timestamp(value * 1000 * 1000)); + bool is_date_time = !(dt.hour() == 0 && dt.minute() == 0 && dt.second() == 0); + size_t bytes_following = is_date_time ? 7 : 4; + buffer.write(reinterpret_cast(&bytes_following), 1); + int year = dt.year(); + int month = dt.month(); + int day = dt.day(); + buffer.write(reinterpret_cast(&year), 2); + buffer.write(reinterpret_cast(&month), 1); + buffer.write(reinterpret_cast(&day), 1); + if (is_date_time) + { + int hour = dt.hourAMPM(); + int minute = dt.minute(); + int second = dt.second(); + buffer.write(reinterpret_cast(&hour), 1); + buffer.write(reinterpret_cast(&minute), 1); + buffer.write(reinterpret_cast(&second), 1); + } + } + else + { + buffer.write(static_cast(0)); + } + break; + } + default: + writeLengthEncodedString(serialized[i], buffer); + break; + } + } + } + + size_t ResultSetRow::getPayloadSize() const + { + return payload_size; + }; + } +} +} diff --git a/src/Core/MySQL/PacketsProtocolBinary.h b/src/Core/MySQL/PacketsProtocolBinary.h new file mode 100644 index 00000000000..69936e527c1 --- /dev/null +++ b/src/Core/MySQL/PacketsProtocolBinary.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include +#include +#include "DataTypes/IDataType.h" +#include "DataTypes/Serializations/ISerialization.h" + +namespace DB +{ + +namespace MySQLProtocol +{ + + namespace ProtocolBinary + { + class ResultSetRow : public IMySQLWritePacket + { + private: + TypeIndex getTypeIndex(DataTypePtr data_type, const ColumnPtr & col) const; + + protected: + int row_num; + const Columns & columns; + const DataTypes & data_types; + const Serializations & serializations; + + std::vector serialized = std::vector(columns.size()); + + size_t null_bitmap_size = (columns.size() + 7) / 8; + std::vector null_bitmap = std::vector(null_bitmap_size, 0); + + size_t payload_size = 0; + + size_t getPayloadSize() const override; + + void writePayloadImpl(WriteBuffer & buffer) const override; + + public: + ResultSetRow(const Serializations & serializations_, const DataTypes & data_types_, const Columns & columns_, int row_num_); + }; + } +} +} diff --git a/src/Core/MySQL/PacketsProtocolText.cpp b/src/Core/MySQL/PacketsProtocolText.cpp index 728e8061e87..9c5bf6b6e05 100644 --- a/src/Core/MySQL/PacketsProtocolText.cpp +++ b/src/Core/MySQL/PacketsProtocolText.cpp @@ -1,7 +1,8 @@ #include -#include #include +#include #include +#include "Core/MySQL/IMySQLWritePacket.h" namespace DB { @@ -9,197 +10,212 @@ namespace DB namespace MySQLProtocol { -namespace ProtocolText -{ - -ResultSetRow::ResultSetRow(const Serializations & serializations, const Columns & columns_, int row_num_) - : columns(columns_), row_num(row_num_) -{ - for (size_t i = 0; i < columns.size(); ++i) + namespace ProtocolText { - if (columns[i]->isNullAt(row_num)) + + ResultSetRow::ResultSetRow(const Serializations & serializations, const Columns & columns_, int row_num_) + : columns(columns_), row_num(row_num_) { - payload_size += 1; - serialized.emplace_back("\xfb"); + for (size_t i = 0; i < columns.size(); ++i) + { + if (columns[i]->isNullAt(row_num)) + { + payload_size += 1; + serialized.emplace_back("\xfb"); + } + else + { + WriteBufferFromOwnString ostr; + serializations[i]->serializeText(*columns[i], row_num, ostr, FormatSettings()); + payload_size += getLengthEncodedStringSize(ostr.str()); + serialized.push_back(std::move(ostr.str())); + } + } } - else + + size_t ResultSetRow::getPayloadSize() const { - WriteBufferFromOwnString ostr; - serializations[i]->serializeText(*columns[i], row_num, ostr, FormatSettings()); - payload_size += getLengthEncodedStringSize(ostr.str()); - serialized.push_back(std::move(ostr.str())); + return payload_size; } + + void ResultSetRow::writePayloadImpl(WriteBuffer & buffer) const + { + for (size_t i = 0; i < columns.size(); ++i) + { + if (columns[i]->isNullAt(row_num)) + buffer.write(serialized[i].data(), 1); + else + writeLengthEncodedString(serialized[i], buffer); + } + } + + void ComFieldList::readPayloadImpl(ReadBuffer & payload) + { + // Command byte has been already read from payload. + readNullTerminated(table, payload); + readStringUntilEOF(field_wildcard, payload); + } + + ColumnDefinition::ColumnDefinition() : character_set(0x00), column_length(0), column_type(MYSQL_TYPE_DECIMAL), flags(0x00) + { + } + + ColumnDefinition::ColumnDefinition( + String schema_, + String table_, + String org_table_, + String name_, + String org_name_, + uint16_t character_set_, + uint32_t column_length_, + ColumnType column_type_, + uint16_t flags_, + uint8_t decimals_, + bool with_defaults_) + : schema(std::move(schema_)) + , table(std::move(table_)) + , org_table(std::move(org_table_)) + , name(std::move(name_)) + , org_name(std::move(org_name_)) + , character_set(character_set_) + , column_length(column_length_) + , column_type(column_type_) + , flags(flags_) + , decimals(decimals_) + , is_comm_field_list_response(with_defaults_) + { + } + + ColumnDefinition::ColumnDefinition( + String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_) + : ColumnDefinition("", "", "", std::move(name_), "", character_set_, column_length_, column_type_, flags_, decimals_) + { + } + + size_t ColumnDefinition::getPayloadSize() const + { + return 12 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + + getLengthEncodedStringSize(org_table) + getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + + getLengthEncodedNumberSize(next_length) + is_comm_field_list_response; + } + + void ColumnDefinition::readPayloadImpl(ReadBuffer & payload) + { + String def; + readLengthEncodedString(def, payload); + assert(def == "def"); + readLengthEncodedString(schema, payload); + readLengthEncodedString(table, payload); + readLengthEncodedString(org_table, payload); + readLengthEncodedString(name, payload); + readLengthEncodedString(org_name, payload); + next_length = readLengthEncodedNumber(payload); + payload.readStrict(reinterpret_cast(&character_set), 2); + payload.readStrict(reinterpret_cast(&column_length), 4); + payload.readStrict(reinterpret_cast(&column_type), 1); + payload.readStrict(reinterpret_cast(&flags), 2); + payload.readStrict(reinterpret_cast(&decimals), 1); + payload.ignore(2); + } + + void ColumnDefinition::writePayloadImpl(WriteBuffer & buffer) const + { + writeLengthEncodedString(std::string("def"), buffer); /// always "def" + writeLengthEncodedString(schema, buffer); + writeLengthEncodedString(table, buffer); + writeLengthEncodedString(org_table, buffer); + writeLengthEncodedString(name, buffer); + writeLengthEncodedString(org_name, buffer); + writeLengthEncodedNumber(next_length, buffer); + buffer.write(reinterpret_cast(&character_set), 2); + buffer.write(reinterpret_cast(&column_length), 4); + buffer.write(reinterpret_cast(&column_type), 1); + buffer.write(reinterpret_cast(&flags), 2); + buffer.write(reinterpret_cast(&decimals), 1); + writeChar(0x0, 2, buffer); + if (is_comm_field_list_response) + { + /// We should write length encoded int with string size + /// followed by string with some "default values" (possibly it's column defaults). + /// But we just send NULL for simplicity. + writeChar(0xfb, buffer); + } + } + + ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index) + { + ColumnType column_type; + CharacterSet charset = CharacterSet::binary; + int flags = 0; + uint8_t decimals = 0; + switch (type_index) + { + case TypeIndex::UInt8: + column_type = ColumnType::MYSQL_TYPE_TINY; + flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; + break; + case TypeIndex::UInt16: + column_type = ColumnType::MYSQL_TYPE_SHORT; + flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; + break; + case TypeIndex::UInt32: + column_type = ColumnType::MYSQL_TYPE_LONG; + flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; + break; + case TypeIndex::UInt64: + column_type = ColumnType::MYSQL_TYPE_LONGLONG; + flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; + break; + case TypeIndex::Int8: + column_type = ColumnType::MYSQL_TYPE_TINY; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Int16: + column_type = ColumnType::MYSQL_TYPE_SHORT; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Int32: + column_type = ColumnType::MYSQL_TYPE_LONG; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Int64: + column_type = ColumnType::MYSQL_TYPE_LONGLONG; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Float32: + column_type = ColumnType::MYSQL_TYPE_FLOAT; + flags = ColumnDefinitionFlags::BINARY_FLAG; + decimals = 31; + break; + case TypeIndex::Float64: + column_type = ColumnType::MYSQL_TYPE_DOUBLE; + flags = ColumnDefinitionFlags::BINARY_FLAG; + decimals = 31; + break; + case TypeIndex::Date: + column_type = ColumnType::MYSQL_TYPE_DATE; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::DateTime: + column_type = ColumnType::MYSQL_TYPE_DATETIME; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Decimal32: + case TypeIndex::Decimal64: + case TypeIndex::Decimal128: + /// MySQL Decimal has max 65 precision and 30 scale. Thus, Decimal256 is reported as a string + column_type = ColumnType::MYSQL_TYPE_DECIMAL; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + default: + column_type = ColumnType::MYSQL_TYPE_STRING; + charset = CharacterSet::utf8_general_ci; + break; + } + return ColumnDefinition(column_name, charset, 0, column_type, flags, decimals); + } + } -} - -size_t ResultSetRow::getPayloadSize() const -{ - return payload_size; -} - -void ResultSetRow::writePayloadImpl(WriteBuffer & buffer) const -{ - for (size_t i = 0; i < columns.size(); ++i) - { - if (columns[i]->isNullAt(row_num)) - buffer.write(serialized[i].data(), 1); - else - writeLengthEncodedString(serialized[i], buffer); - } -} - -void ComFieldList::readPayloadImpl(ReadBuffer & payload) -{ - // Command byte has been already read from payload. - readNullTerminated(table, payload); - readStringUntilEOF(field_wildcard, payload); -} - -ColumnDefinition::ColumnDefinition() - : character_set(0x00), column_length(0), column_type(MYSQL_TYPE_DECIMAL), flags(0x00) -{ -} - -ColumnDefinition::ColumnDefinition( - String schema_, String table_, String org_table_, String name_, String org_name_, uint16_t character_set_, uint32_t column_length_, - ColumnType column_type_, uint16_t flags_, uint8_t decimals_, bool with_defaults_) - : schema(std::move(schema_)), table(std::move(table_)), org_table(std::move(org_table_)), name(std::move(name_)), - org_name(std::move(org_name_)), character_set(character_set_), column_length(column_length_), column_type(column_type_), - flags(flags_), decimals(decimals_), is_comm_field_list_response(with_defaults_) -{ -} - -ColumnDefinition::ColumnDefinition( - String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_) - : ColumnDefinition("", "", "", std::move(name_), "", character_set_, column_length_, column_type_, flags_, decimals_) -{ -} - -size_t ColumnDefinition::getPayloadSize() const -{ - return 12 + - getLengthEncodedStringSize("def") + - getLengthEncodedStringSize(schema) + - getLengthEncodedStringSize(table) + - getLengthEncodedStringSize(org_table) + - getLengthEncodedStringSize(name) + - getLengthEncodedStringSize(org_name) + - getLengthEncodedNumberSize(next_length) + - is_comm_field_list_response; -} - -void ColumnDefinition::readPayloadImpl(ReadBuffer & payload) -{ - String def; - readLengthEncodedString(def, payload); - assert(def == "def"); - readLengthEncodedString(schema, payload); - readLengthEncodedString(table, payload); - readLengthEncodedString(org_table, payload); - readLengthEncodedString(name, payload); - readLengthEncodedString(org_name, payload); - next_length = readLengthEncodedNumber(payload); - payload.readStrict(reinterpret_cast(&character_set), 2); - payload.readStrict(reinterpret_cast(&column_length), 4); - payload.readStrict(reinterpret_cast(&column_type), 1); - payload.readStrict(reinterpret_cast(&flags), 2); - payload.readStrict(reinterpret_cast(&decimals), 1); - payload.ignore(2); -} - -void ColumnDefinition::writePayloadImpl(WriteBuffer & buffer) const -{ - writeLengthEncodedString(std::string("def"), buffer); /// always "def" - writeLengthEncodedString(schema, buffer); - writeLengthEncodedString(table, buffer); - writeLengthEncodedString(org_table, buffer); - writeLengthEncodedString(name, buffer); - writeLengthEncodedString(org_name, buffer); - writeLengthEncodedNumber(next_length, buffer); - buffer.write(reinterpret_cast(&character_set), 2); - buffer.write(reinterpret_cast(&column_length), 4); - buffer.write(reinterpret_cast(&column_type), 1); - buffer.write(reinterpret_cast(&flags), 2); - buffer.write(reinterpret_cast(&decimals), 1); - writeChar(0x0, 2, buffer); - if (is_comm_field_list_response) - { - /// We should write length encoded int with string size - /// followed by string with some "default values" (possibly it's column defaults). - /// But we just send NULL for simplicity. - writeChar(0xfb, buffer); - } -} - -ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index) -{ - ColumnType column_type; - CharacterSet charset = CharacterSet::binary; - int flags = 0; - switch (type_index) - { - case TypeIndex::UInt8: - column_type = ColumnType::MYSQL_TYPE_TINY; - flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; - break; - case TypeIndex::UInt16: - column_type = ColumnType::MYSQL_TYPE_SHORT; - flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; - break; - case TypeIndex::UInt32: - column_type = ColumnType::MYSQL_TYPE_LONG; - flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; - break; - case TypeIndex::UInt64: - column_type = ColumnType::MYSQL_TYPE_LONGLONG; - flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; - break; - case TypeIndex::Int8: - column_type = ColumnType::MYSQL_TYPE_TINY; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Int16: - column_type = ColumnType::MYSQL_TYPE_SHORT; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Int32: - column_type = ColumnType::MYSQL_TYPE_LONG; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Int64: - column_type = ColumnType::MYSQL_TYPE_LONGLONG; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Float32: - column_type = ColumnType::MYSQL_TYPE_FLOAT; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Float64: - column_type = ColumnType::MYSQL_TYPE_DOUBLE; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Date: - column_type = ColumnType::MYSQL_TYPE_DATE; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::DateTime: - column_type = ColumnType::MYSQL_TYPE_DATETIME; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::String: - case TypeIndex::FixedString: - column_type = ColumnType::MYSQL_TYPE_STRING; - charset = CharacterSet::utf8_general_ci; - break; - default: - column_type = ColumnType::MYSQL_TYPE_STRING; - charset = CharacterSet::utf8_general_ci; - break; - } - return ColumnDefinition(column_name, charset, 0, column_type, flags, 0); -} - -} } diff --git a/src/DataTypes/DataTypesDecimal.cpp b/src/DataTypes/DataTypesDecimal.cpp index fa044d4ac9c..2af216529e5 100644 --- a/src/DataTypes/DataTypesDecimal.cpp +++ b/src/DataTypes/DataTypesDecimal.cpp @@ -1,13 +1,13 @@ #include #include -#include #include #include #include #include #include #include +#include #include @@ -31,6 +31,12 @@ std::string DataTypeDecimal::doGetName() const template std::string DataTypeDecimal::getSQLCompatibleName() const { + /// See https://dev.mysql.com/doc/refman/8.0/en/precision-math-decimal-characteristics.html + /// DECIMAL(M,D) + /// M is the maximum number of digits (the precision). It has a range of 1 to 65. + /// D is the number of digits to the right of the decimal point (the scale). It has a range of 0 to 30 and must be no larger than M. + if (this->precision > 65 || this->scale > 30) + return "TEXT"; return fmt::format("DECIMAL({}, {})", this->precision, this->scale); } @@ -75,14 +81,14 @@ SerializationPtr DataTypeDecimal::doGetDefaultSerialization() const static DataTypePtr create(const ASTPtr & arguments) { if (!arguments || arguments->children.size() != 2) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Decimal data type family must have exactly two arguments: precision and scale"); + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Decimal data type family must have exactly two arguments: precision and scale"); const auto * precision = arguments->children[0]->as(); const auto * scale = arguments->children[1]->as(); - if (!precision || precision->value.getType() != Field::Types::UInt64 || - !scale || !(scale->value.getType() == Field::Types::Int64 || scale->value.getType() == Field::Types::UInt64)) + if (!precision || precision->value.getType() != Field::Types::UInt64 || !scale + || !(scale->value.getType() == Field::Types::Int64 || scale->value.getType() == Field::Types::UInt64)) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Decimal data type family must have two numbers as its arguments"); UInt64 precision_value = precision->value.get(); @@ -95,13 +101,15 @@ template static DataTypePtr createExact(const ASTPtr & arguments) { if (!arguments || arguments->children.size() != 1) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Decimal32 | Decimal64 | Decimal128 | Decimal256 data type family must have exactly one arguments: scale"); + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Decimal32 | Decimal64 | Decimal128 | Decimal256 data type family must have exactly one arguments: scale"); const auto * scale_arg = arguments->children[0]->as(); if (!scale_arg || !(scale_arg->value.getType() == Field::Types::Int64 || scale_arg->value.getType() == Field::Types::UInt64)) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Decimal32 | Decimal64 | Decimal128 | Decimal256 data type family must have a one number as its argument"); + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Decimal32 | Decimal64 | Decimal128 | Decimal256 data type family must have a one number as its argument"); UInt64 precision = DecimalUtils::max_precision; UInt64 scale = scale_arg->value.get(); diff --git a/src/Formats/FormatSettings.h b/src/Formats/FormatSettings.h index 2c283dcc2b7..56359fd0bea 100644 --- a/src/Formats/FormatSettings.h +++ b/src/Formats/FormatSettings.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include #include @@ -48,9 +48,9 @@ struct FormatSettings enum class DateTimeInputFormat { - Basic, /// Default format for fast parsing: YYYY-MM-DD hh:mm:ss (ISO-8601 without fractional part and timezone) or NNNNNNNNNN unix timestamp. - BestEffort, /// Use sophisticated rules to parse whatever possible. - BestEffortUS /// Use sophisticated rules to parse American style: mm/dd/yyyy + Basic, /// Default format for fast parsing: YYYY-MM-DD hh:mm:ss (ISO-8601 without fractional part and timezone) or NNNNNNNNNN unix timestamp. + BestEffort, /// Use sophisticated rules to parse whatever possible. + BestEffortUS /// Use sophisticated rules to parse American style: mm/dd/yyyy }; DateTimeInputFormat date_time_input_format = DateTimeInputFormat::Basic; @@ -282,6 +282,14 @@ struct FormatSettings uint32_t client_capabilities = 0; size_t max_packet_size = 0; uint8_t * sequence_id = nullptr; /// Not null if it's MySQLWire output format used to handle MySQL protocol connections. + /** + * COM_QUERY uses Text ResultSet + * https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_text_resultset.html + * COM_STMT_EXECUTE uses Binary Protocol ResultSet + * https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute_response.html + * By default, use Text ResultSet. + */ + bool use_binary_result_set = false; } mysql_wire; struct diff --git a/src/Interpreters/InterpreterShowColumnsQuery.cpp b/src/Interpreters/InterpreterShowColumnsQuery.cpp index c86d3c753c4..922f9887a82 100644 --- a/src/Interpreters/InterpreterShowColumnsQuery.cpp +++ b/src/Interpreters/InterpreterShowColumnsQuery.cpp @@ -37,7 +37,7 @@ String InterpreterShowColumnsQuery::getRewrittenQuery() SELECT name AS field, type AS type, - startsWith(type, 'Nullable') AS null, + if(startsWith(type, 'Nullable'), 'YES', 'NO') AS null, trim(concatWithSeparator(' ', if (is_in_primary_key, 'PRI', ''), if (is_in_sorting_key, 'SOR', ''))) AS key, if (default_kind IN ('ALIAS', 'DEFAULT', 'MATERIALIZED'), default_expression, NULL) AS default, '' AS extra )"; diff --git a/src/Processors/Formats/Impl/MySQLOutputFormat.cpp b/src/Processors/Formats/Impl/MySQLOutputFormat.cpp index f2157f63c25..3dafe560281 100644 --- a/src/Processors/Formats/Impl/MySQLOutputFormat.cpp +++ b/src/Processors/Formats/Impl/MySQLOutputFormat.cpp @@ -1,11 +1,12 @@ -#include #include +#include #include #include #include #include #include - +#include +#include "Common/logger_useful.h" namespace DB { @@ -13,17 +14,18 @@ namespace DB using namespace MySQLProtocol; using namespace MySQLProtocol::Generic; using namespace MySQLProtocol::ProtocolText; - +using namespace MySQLProtocol::ProtocolBinary; MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header_, const FormatSettings & settings_) - : IOutputFormat(header_, out_) - , client_capabilities(settings_.mysql_wire.client_capabilities) + : IOutputFormat(header_, out_), client_capabilities(settings_.mysql_wire.client_capabilities) { /// MySQlWire is a special format that is usually used as output format for MySQL protocol connections. /// In this case we have a correct `sequence_id` stored in `settings_.mysql_wire`. /// But it's also possible to specify MySQLWire as output format for clickhouse-client or clickhouse-local. /// There is no `sequence_id` stored in `settings_.mysql_wire` in this case, so we create a dummy one. sequence_id = settings_.mysql_wire.sequence_id ? settings_.mysql_wire.sequence_id : &dummy_sequence_id; + /// Switch between Text (COM_QUERY) and Binary (COM_EXECUTE_STMT) ResultSet + use_binary_result_set = settings_.mysql_wire.use_binary_result_set; const auto & header = getPort(PortKind::Main).getHeader(); data_types = header.getDataTypes(); @@ -54,7 +56,7 @@ void MySQLOutputFormat::writePrefix() packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId())); } - if (!(client_capabilities & Capability::CLIENT_DEPRECATE_EOF)) + if (!(client_capabilities & Capability::CLIENT_DEPRECATE_EOF) && !use_binary_result_set) { packet_endpoint->sendPacket(EOFPacket(0, 0)); } @@ -63,39 +65,67 @@ void MySQLOutputFormat::writePrefix() void MySQLOutputFormat::consume(Chunk chunk) { - for (size_t i = 0; i < chunk.getNumRows(); ++i) + if (!use_binary_result_set) { - ProtocolText::ResultSetRow row_packet(serializations, chunk.getColumns(), static_cast(i)); - packet_endpoint->sendPacket(row_packet); + for (size_t i = 0; i < chunk.getNumRows(); ++i) + { + ProtocolText::ResultSetRow row_packet(serializations, chunk.getColumns(), static_cast(i)); + packet_endpoint->sendPacket(row_packet); + } + } + else + { + for (size_t i = 0; i < chunk.getNumRows(); ++i) + { + ProtocolBinary::ResultSetRow row_packet(serializations, data_types, chunk.getColumns(), static_cast(i)); + packet_endpoint->sendPacket(row_packet); + } } } void MySQLOutputFormat::finalizeImpl() { - size_t affected_rows = 0; - std::string human_readable_info; - if (QueryStatusPtr process_list_elem = getContext()->getProcessListElement()) + if (!use_binary_result_set) { - CurrentThread::finalizePerformanceCounters(); - QueryStatusInfo info = process_list_elem->getInfo(); - affected_rows = info.written_rows; - double elapsed_seconds = static_cast(info.elapsed_microseconds) / 1000000.0; - human_readable_info = fmt::format( - "Read {} rows, {} in {} sec., {} rows/sec., {}/sec.", - info.read_rows, - ReadableSize(info.read_bytes), - elapsed_seconds, - static_cast(info.read_rows / elapsed_seconds), - ReadableSize(info.read_bytes / elapsed_seconds)); - } + size_t affected_rows = 0; + std::string human_readable_info; + if (QueryStatusPtr process_list_elem = getContext()->getProcessListElement()) + { + CurrentThread::finalizePerformanceCounters(); + QueryStatusInfo info = process_list_elem->getInfo(); + affected_rows = info.written_rows; + double elapsed_seconds = static_cast(info.elapsed_microseconds) / 1000000.0; + human_readable_info = fmt::format( + "Read {} rows, {} in {} sec., {} rows/sec., {}/sec.", + info.read_rows, + ReadableSize(info.read_bytes), + elapsed_seconds, + static_cast(info.read_rows / elapsed_seconds), + ReadableSize(info.read_bytes / elapsed_seconds)); + } - const auto & header = getPort(PortKind::Main).getHeader(); - if (header.columns() == 0) - packet_endpoint->sendPacket(OKPacket(0x0, client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); - else if (client_capabilities & CLIENT_DEPRECATE_EOF) - packet_endpoint->sendPacket(OKPacket(0xfe, client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); + const auto & header = getPort(PortKind::Main).getHeader(); + if (header.columns() == 0) + packet_endpoint->sendPacket(OKPacket(0x0, client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); + else if (client_capabilities & CLIENT_DEPRECATE_EOF) + packet_endpoint->sendPacket(OKPacket(0xfe, client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); + else + packet_endpoint->sendPacket(EOFPacket(0, 0), true); + } else - packet_endpoint->sendPacket(EOFPacket(0, 0), true); + { + size_t affected_rows = 0; + if (QueryStatusPtr process_list_elem = getContext()->getProcessListElement()) + { + CurrentThread::finalizePerformanceCounters(); + QueryStatusInfo info = process_list_elem->getInfo(); + affected_rows = info.written_rows; + } + if (client_capabilities & CLIENT_DEPRECATE_EOF) + packet_endpoint->sendPacket(OKPacket(0xfe, client_capabilities, affected_rows, 0, 0, "", ""), true); + else + packet_endpoint->sendPacket(EOFPacket(0, 0), true); + } } void MySQLOutputFormat::flush() @@ -107,9 +137,8 @@ void registerOutputFormatMySQLWire(FormatFactory & factory) { factory.registerOutputFormat( "MySQLWire", - [](WriteBuffer & buf, - const Block & sample, - const FormatSettings & settings) { return std::make_shared(buf, sample, settings); }); + [](WriteBuffer & buf, const Block & sample, const FormatSettings & settings) + { return std::make_shared(buf, sample, settings); }); } } diff --git a/src/Processors/Formats/Impl/MySQLOutputFormat.h b/src/Processors/Formats/Impl/MySQLOutputFormat.h index 9481ef67070..6161b6bdc14 100644 --- a/src/Processors/Formats/Impl/MySQLOutputFormat.h +++ b/src/Processors/Formats/Impl/MySQLOutputFormat.h @@ -1,7 +1,7 @@ #pragma once -#include #include +#include #include #include @@ -39,6 +39,7 @@ private: MySQLProtocol::PacketEndpointPtr packet_endpoint; DataTypes data_types; Serializations serializations; + bool use_binary_result_set = false; }; } diff --git a/src/Server/MySQLHandler.cpp b/src/Server/MySQLHandler.cpp index 868575b701f..3715dfea9f7 100644 --- a/src/Server/MySQLHandler.cpp +++ b/src/Server/MySQLHandler.cpp @@ -1,29 +1,29 @@ #include "MySQLHandler.h" #include -#include -#include +#include +#include +#include #include #include -#include #include #include -#include -#include -#include #include #include #include +#include #include #include -#include +#include +#include +#include #include #include -#include -#include -#include -#include #include +#include +#include +#include +#include #include "config_version.h" @@ -67,10 +67,7 @@ static String killConnectionIdReplacementQuery(const String & query); static String selectLimitReplacementQuery(const String & query); MySQLHandler::MySQLHandler( - IServer & server_, - TCPServer & tcp_server_, - const Poco::Net::StreamSocket & socket_, - bool ssl_enabled, uint32_t connection_id_) + IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool ssl_enabled, uint32_t connection_id_) : Poco::Net::TCPServerConnection(socket_) , server(server_) , tcp_server(tcp_server_) @@ -78,7 +75,8 @@ MySQLHandler::MySQLHandler( , connection_id(connection_id_) , auth_plugin(new MySQLProtocol::Authentication::Native41()) { - server_capabilities = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF; + server_capabilities = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF; if (ssl_enabled) server_capabilities |= CLIENT_SSL; @@ -104,8 +102,13 @@ void MySQLHandler::run() try { - Handshake handshake(server_capabilities, connection_id, VERSION_STRING + String("-") + VERSION_NAME, - auth_plugin->getName(), auth_plugin->getAuthPluginData(), CharacterSet::utf8_general_ci); + Handshake handshake( + server_capabilities, + connection_id, + VERSION_STRING + String("-") + VERSION_NAME, + auth_plugin->getName(), + auth_plugin->getAuthPluginData(), + CharacterSet::utf8_general_ci); packet_endpoint->sendPacket(handshake, true); LOG_TRACE(log, "Sent handshake"); @@ -115,8 +118,10 @@ void MySQLHandler::run() client_capabilities = handshake_response.capability_flags; max_packet_size = handshake_response.max_packet_size ? handshake_response.max_packet_size : MAX_PACKET_LENGTH; - LOG_TRACE(log, - "Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: {}", + LOG_TRACE( + log, + "Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: " + "{}", handshake_response.capability_flags, handshake_response.max_packet_size, static_cast(handshake_response.character_set), @@ -160,8 +165,8 @@ void MySQLHandler::run() // For commands which are executed without MemoryTracker. LimitReadBuffer limited_payload(payload, 10000, /* trow_exception */ true, /* exact_limit */ {}, "too long MySQL packet."); - LOG_DEBUG(log, "Received command: {}. Connection id: {}.", - static_cast(static_cast(command)), connection_id); + LOG_DEBUG( + log, "Received command: {}. Connection id: {}.", static_cast(static_cast(command)), connection_id); if (!tcp_server.isOpen()) return; @@ -175,7 +180,7 @@ void MySQLHandler::run() comInitDB(limited_payload); break; case COM_QUERY: - comQuery(payload); + comQuery(payload, false); break; case COM_FIELD_LIST: comFieldList(limited_payload); @@ -227,13 +232,15 @@ void MySQLHandler::finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResp size_t pos = 0; /// Reads at least count and at most packet_size bytes. - auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void { + auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void + { while (pos < count) { int ret = socket().receiveBytes(buf + pos, static_cast(packet_size - pos)); if (ret == 0) { - throw Exception(ErrorCodes::CANNOT_READ_ALL_DATA, "Cannot read all data. Bytes read: {}. Bytes expected: 3", std::to_string(pos)); + throw Exception( + ErrorCodes::CANNOT_READ_ALL_DATA, "Cannot read all data. Bytes read: {}. Bytes expected: 3", std::to_string(pos)); } pos += ret; } @@ -272,7 +279,8 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl authPluginSSL(); } - std::optional auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional(initial_auth_response) : std::nullopt; + std::optional auth_response + = auth_plugin_name == auth_plugin->getName() ? std::make_optional(initial_auth_response) : std::nullopt; auth_plugin->authenticate(user_name, *session, auth_response, packet_endpoint, secure_connection, socket().peerAddress()); } catch (const Exception & exc) @@ -304,8 +312,17 @@ void MySQLHandler::comFieldList(ReadBuffer & payload) for (const NameAndTypePair & column : metadata_snapshot->getColumns().getAll()) { ColumnDefinition column_definition( - database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0, true - ); + database, + packet.table, + packet.table, + column.name, + column.name, + CharacterSet::binary, + 100, + ColumnType::MYSQL_TYPE_STRING, + 0, + 0, + true); packet_endpoint->sendPacket(column_definition); } packet_endpoint->sendPacket(OKPacket(0xfe, client_capabilities, 0, 0, 0), true); @@ -318,7 +335,7 @@ void MySQLHandler::comPing() static bool isFederatedServerSetupSetCommand(const String & query); -void MySQLHandler::comQuery(ReadBuffer & payload) +void MySQLHandler::comQuery(ReadBuffer & payload, bool use_binary_protocol_result_set) { String query = String(payload.position(), payload.buffer().end()); @@ -350,20 +367,22 @@ void MySQLHandler::comQuery(ReadBuffer & payload) query_context->setCurrentQueryId(fmt::format("mysql:{}:{}", connection_id, toString(UUIDHelpers::generateV4()))); CurrentThread::QueryScope query_scope{query_context}; - std::atomic affected_rows {0}; + std::atomic affected_rows{0}; auto prev = query_context->getProgressCallback(); - query_context->setProgressCallback([&, my_prev = prev](const Progress & progress) - { - if (my_prev) - my_prev(progress); + query_context->setProgressCallback( + [&, my_prev = prev](const Progress & progress) + { + if (my_prev) + my_prev(progress); - affected_rows += progress.written_rows; - }); + affected_rows += progress.written_rows; + }); FormatSettings format_settings; format_settings.mysql_wire.client_capabilities = client_capabilities; format_settings.mysql_wire.max_packet_size = max_packet_size; format_settings.mysql_wire.sequence_id = &sequence_id; + format_settings.mysql_wire.use_binary_result_set = use_binary_protocol_result_set; auto set_result_details = [&with_output](const QueryResultDetails & details) { @@ -385,11 +404,18 @@ void MySQLHandler::comQuery(ReadBuffer & payload) void MySQLHandler::comStmtPrepare(DB::ReadBuffer & payload) { + if (prepared_statements_map.size() > 10000) /// Shouldn't happen in reality as COM_STMT_CLOSE cleans up the elements + { + LOG_ERROR(log, "Too many prepared statements"); + packet_endpoint->sendPacket(ERRPacket(), true); + return; + } + String query; readStringUntilEOF(query, payload); uint32_t statement_id = current_prepared_statement_id; - if (current_prepared_statement_id == std::numeric_limits::max()) [[unlikely]] + if (current_prepared_statement_id == std::numeric_limits::max()) { current_prepared_statement_id = 0; } @@ -400,7 +426,7 @@ void MySQLHandler::comStmtPrepare(DB::ReadBuffer & payload) // Key collisions should not happen here, as we remove the elements from the map with COM_STMT_CLOSE, // and we have quite a big range of available identifiers with 32-bit unsigned integer - if (prepared_statements_map.contains(statement_id)) [[unlikely]] + if (prepared_statements_map.contains(statement_id)) { LOG_ERROR( log, @@ -411,8 +437,8 @@ void MySQLHandler::comStmtPrepare(DB::ReadBuffer & payload) packet_endpoint->sendPacket(ERRPacket(), true); return; } - prepared_statements_map.emplace(statement_id, query); + prepared_statements_map.emplace(statement_id, query); packet_endpoint->sendPacket(PrepareStatementResponseOK(statement_id, 0, 0, 0), true); } @@ -421,7 +447,7 @@ void MySQLHandler::comStmtExecute(ReadBuffer & payload) uint32_t statement_id; payload.readStrict(reinterpret_cast(&statement_id), 4); - if (!prepared_statements_map.contains(statement_id)) [[unlikely]] + if (!prepared_statements_map.contains(statement_id)) { LOG_ERROR(log, "Could not find prepared statement with id {}", statement_id); packet_endpoint->sendPacket(ERRPacket(), true); @@ -430,14 +456,16 @@ void MySQLHandler::comStmtExecute(ReadBuffer & payload) // Temporary workaround as we work only with queries that do not bind any parameters atm ReadBufferFromString com_query_payload(prepared_statements_map.at(statement_id)); - MySQLHandler::comQuery(com_query_payload); + MySQLHandler::comQuery(com_query_payload, true); }; -void MySQLHandler::comStmtClose([[maybe_unused]] ReadBuffer & payload) { +void MySQLHandler::comStmtClose(ReadBuffer & payload) +{ uint32_t statement_id; payload.readStrict(reinterpret_cast(&statement_id), 4); - if (prepared_statements_map.contains(statement_id)) { + if (prepared_statements_map.contains(statement_id)) + { prepared_statements_map.erase(statement_id); } @@ -447,13 +475,17 @@ void MySQLHandler::comStmtClose([[maybe_unused]] ReadBuffer & payload) { void MySQLHandler::authPluginSSL() { - throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, - "ClickHouse was built without SSL support. Try specifying password using double SHA1 in users.xml."); + throw Exception( + ErrorCodes::SUPPORT_IS_DISABLED, + "ClickHouse was built without SSL support. Try specifying password using double SHA1 in users.xml."); } void MySQLHandler::finishHandshakeSSL( - [[maybe_unused]] size_t packet_size, [[maybe_unused]] char * buf, [[maybe_unused]] size_t pos, - [[maybe_unused]] std::function read_bytes, [[maybe_unused]] MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) + [[maybe_unused]] size_t packet_size, + [[maybe_unused]] char * buf, + [[maybe_unused]] size_t pos, + [[maybe_unused]] std::function read_bytes, + [[maybe_unused]] MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) { throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "Client requested SSL, while it is disabled."); } @@ -467,10 +499,9 @@ MySQLHandlerSSL::MySQLHandlerSSL( uint32_t connection_id_, RSA & public_key_, RSA & private_key_) - : MySQLHandler(server_, tcp_server_, socket_, ssl_enabled, connection_id_) - , public_key(public_key_) - , private_key(private_key_) -{} + : MySQLHandler(server_, tcp_server_, socket_, ssl_enabled, connection_id_), public_key(public_key_), private_key(private_key_) +{ +} void MySQLHandlerSSL::authPluginSSL() { @@ -478,7 +509,10 @@ void MySQLHandlerSSL::authPluginSSL() } void MySQLHandlerSSL::finishHandshakeSSL( - size_t packet_size, char *buf, size_t pos, std::function read_bytes, + size_t packet_size, + char * buf, + size_t pos, + std::function read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) { read_bytes(packet_size); /// Reading rest SSLRequest. @@ -508,8 +542,8 @@ static bool isFederatedServerSetupSetCommand(const String & query) "|(^(SET AUTOCOMMIT(.*)))" "|(^(SET sql_mode(.*)))" "|(^(SET @@(.*)))" - "|(^(SET SESSION TRANSACTION ISOLATION LEVEL(.*)))" - , std::regex::icase}; + "|(^(SET SESSION TRANSACTION ISOLATION LEVEL(.*)))", + std::regex::icase}; return 1 == std::regex_match(query, expr); } diff --git a/src/Server/MySQLHandler.h b/src/Server/MySQLHandler.h index 6b8cc56a46e..a412b647ae2 100644 --- a/src/Server/MySQLHandler.h +++ b/src/Server/MySQLHandler.h @@ -1,12 +1,12 @@ #pragma once -#include -#include -#include #include -#include #include +#include #include +#include +#include +#include #include "IServer.h" #include "config.h" @@ -19,7 +19,7 @@ namespace CurrentMetrics { - extern const Metric MySQLConnection; +extern const Metric MySQLConnection; } namespace DB @@ -32,11 +32,7 @@ class MySQLHandler : public Poco::Net::TCPServerConnection { public: MySQLHandler( - IServer & server_, - TCPServer & tcp_server_, - const Poco::Net::StreamSocket & socket_, - bool ssl_enabled, - uint32_t connection_id_); + IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool ssl_enabled, uint32_t connection_id_); void run() final; @@ -46,7 +42,7 @@ protected: /// Enables SSL, if client requested. void finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResponse &); - void comQuery(ReadBuffer & payload); + void comQuery(ReadBuffer & payload, bool use_binary_protocol_result_set); void comFieldList(ReadBuffer & payload); @@ -63,7 +59,12 @@ protected: void comStmtClose(ReadBuffer & payload); virtual void authPluginSSL(); - virtual void finishHandshakeSSL(size_t packet_size, char * buf, size_t pos, std::function read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet); + virtual void finishHandshakeSSL( + size_t packet_size, + char * buf, + size_t pos, + std::function read_bytes, + MySQLProtocol::ConnectionPhase::HandshakeResponse & packet); IServer & server; TCPServer & tcp_server; @@ -109,8 +110,11 @@ private: void authPluginSSL() override; void finishHandshakeSSL( - size_t packet_size, char * buf, size_t pos, - std::function read_bytes, MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) override; + size_t packet_size, + char * buf, + size_t pos, + std::function read_bytes, + MySQLProtocol::ConnectionPhase::HandshakeResponse & packet) override; RSA & public_key; RSA & private_key; diff --git a/src/Storages/System/InformationSchema/key_column_usage.sql b/src/Storages/System/InformationSchema/key_column_usage.sql new file mode 100644 index 00000000000..43630b8c8b9 --- /dev/null +++ b/src/Storages/System/InformationSchema/key_column_usage.sql @@ -0,0 +1,38 @@ +ATTACH VIEW key_column_usage + ( + `referenced_table_schema` Nullable(String), + `referenced_table_name` Nullable(String), + `referenced_column_name` Nullable(String), + `table_schema` String, + `table_name` String, + `column_name` Nullable(String), + `ordinal_position` UInt32, + `constraint_name` Nullable(String), + `REFERENCED_TABLE_SCHEMA` Nullable(String), + `REFERENCED_TABLE_NAME` Nullable(String), + `REFERENCED_COLUMN_NAME` Nullable(String), + `TABLE_SCHEMA` String, + `TABLE_NAME` String, + `COLUMN_NAME` Nullable(String), + `ORDINAL_POSITION` UInt32, + `CONSTRAINT_NAME` Nullable(String) + ) AS +SELECT NULL AS `referenced_table_schema`, + NULL AS `referenced_table_name`, + NULL AS `referenced_column_name`, + database AS `table_schema`, + table AS `table_name`, + name AS `column_name`, + position AS `ordinal_position`, + 'PRIMARY' AS `constraint_name`, + + `referenced_table_schema` AS `REFERENCED_TABLE_SCHEMA`, + `referenced_table_name` AS `REFERENCED_TABLE_NAME`, + `referenced_column_name` AS `REFERENCED_COLUMN_NAME`, + `table_schema` AS `TABLE_SCHEMA`, + `table_name` AS `TABLE_NAME`, + `column_name` AS `COLUMN_NAME`, + `ordinal_position` AS `ORDINAL_POSITION`, + `constraint_name` AS `CONSTRAINT_NAME` +FROM system.columns +WHERE is_in_primary_key; \ No newline at end of file diff --git a/src/Storages/System/InformationSchema/referential_constraints.sql b/src/Storages/System/InformationSchema/referential_constraints.sql new file mode 100644 index 00000000000..8216b8fff83 --- /dev/null +++ b/src/Storages/System/InformationSchema/referential_constraints.sql @@ -0,0 +1,25 @@ +ATTACH VIEW referential_constraints + ( + `constraint_name` Nullable(String), + `constraint_schema` String, + `table_name` String, + `update_rule` String, + `delete_rule` String, + `CONSTRAINT_NAME` Nullable(String), + `CONSTRAINT_SCHEMA` String, + `TABLE_NAME` String, + `UPDATE_RULE` String, + `DELETE_RULE` String + ) AS +SELECT NULL AS `constraint_name`, + '' AS `constraint_schema`, + '' AS `table_name`, + '' AS `update_rule`, + '' AS `delete_rule`, + + NULL AS `CONSTRAINT_NAME`, + '' AS `CONSTRAINT_SCHEMA`, + '' AS `TABLE_NAME`, + '' AS `UPDATE_RULE`, + '' AS `DELETE_RULE` +WHERE false; \ No newline at end of file diff --git a/src/Storages/System/InformationSchema/schemata.sql b/src/Storages/System/InformationSchema/schemata.sql index 9686fcbf4fa..ca4ad4f7310 100644 --- a/src/Storages/System/InformationSchema/schemata.sql +++ b/src/Storages/System/InformationSchema/schemata.sql @@ -1,26 +1,33 @@ ATTACH VIEW schemata -( - `catalog_name` String, - `schema_name` String, - `schema_owner` String, - `default_character_set_catalog` Nullable(String), - `default_character_set_schema` Nullable(String), - `default_character_set_name` Nullable(String), - `sql_path` Nullable(String), - `CATALOG_NAME` String ALIAS catalog_name, - `SCHEMA_NAME` String ALIAS schema_name, - `SCHEMA_OWNER` String ALIAS schema_owner, - `DEFAULT_CHARACTER_SET_CATALOG` Nullable(String) ALIAS default_character_set_catalog, - `DEFAULT_CHARACTER_SET_SCHEMA` Nullable(String) ALIAS default_character_set_schema, - `DEFAULT_CHARACTER_SET_NAME` Nullable(String) ALIAS default_character_set_name, - `SQL_PATH` Nullable(String) ALIAS sql_path -) AS -SELECT - name AS catalog_name, - name AS schema_name, - 'default' AS schema_owner, - NULL AS default_character_set_catalog, - NULL AS default_character_set_schema, - NULL AS default_character_set_name, - NULL AS sql_path + ( + `catalog_name` String, + `schema_name` String, + `schema_owner` String, + `default_character_set_catalog` Nullable(String), + `default_character_set_schema` Nullable(String), + `default_character_set_name` Nullable(String), + `sql_path` Nullable(String), + `CATALOG_NAME` String, + `SCHEMA_NAME` String, + `SCHEMA_OWNER` String, + `DEFAULT_CHARACTER_SET_CATALOG` Nullable(String), + `DEFAULT_CHARACTER_SET_SCHEMA` Nullable(String), + `DEFAULT_CHARACTER_SET_NAME` Nullable(String), + `SQL_PATH` Nullable(String) + ) AS +SELECT name AS `catalog_name`, + name AS `schema_name`, + 'default' AS `schema_owner`, + NULL AS `default_character_set_catalog`, + NULL AS `default_character_set_schema`, + NULL AS `default_character_set_name`, + NULL AS `sql_path`, + + catalog_name AS `CATALOG_NAME`, + schema_name AS `SCHEMA_NAME`, + schema_owner AS `SCHEMA_OWNER`, + NULL AS `DEFAULT_CHARACTER_SET_CATALOG`, + NULL AS `DEFAULT_CHARACTER_SET_SCHEMA`, + NULL AS `DEFAULT_CHARACTER_SET_NAME`, + NULL AS `SQL_PATH` FROM system.databases diff --git a/src/Storages/System/InformationSchema/tables.sql b/src/Storages/System/InformationSchema/tables.sql index 8eea3713923..b3bbfa72517 100644 --- a/src/Storages/System/InformationSchema/tables.sql +++ b/src/Storages/System/InformationSchema/tables.sql @@ -1,17 +1,35 @@ ATTACH VIEW tables -( - `table_catalog` String, - `table_schema` String, - `table_name` String, - `table_type` Enum8('BASE TABLE' = 1, 'VIEW' = 2, 'FOREIGN TABLE' = 3, 'LOCAL TEMPORARY' = 4, 'SYSTEM VIEW' = 5), - `TABLE_CATALOG` String ALIAS table_catalog, - `TABLE_SCHEMA` String ALIAS table_schema, - `TABLE_NAME` String ALIAS table_name, - `TABLE_TYPE` Enum8('BASE TABLE' = 1, 'VIEW' = 2, 'FOREIGN TABLE' = 3, 'LOCAL TEMPORARY' = 4, 'SYSTEM VIEW' = 5) ALIAS table_type -) AS -SELECT - database AS table_catalog, - database AS table_schema, - name AS table_name, - multiIf(is_temporary, 4, engine like '%View', 2, engine LIKE 'System%', 5, has_own_data = 0, 3, 1) AS table_type -FROM system.tables + ( + `table_catalog` String, + `table_schema` String, + `table_name` String, + `table_type` String, + `table_comment` String, + `table_collation` String, + `TABLE_CATALOG` String, + `TABLE_SCHEMA` String, + `TABLE_NAME` String, + `TABLE_TYPE` String, + `TABLE_COMMENT` String, + `TABLE_COLLATION` String + ) AS +SELECT database AS `table_catalog`, + database AS `table_schema`, + name AS `table_name`, + comment AS `table_comment`, + multiIf( + is_temporary, 'LOCAL TEMPORARY', + engine LIKE '%View', 'VIEW', + engine LIKE 'System%', 'SYSTEM VIEW', + has_own_data = 0, 'FOREIGN TABLE', + 'BASE TABLE' + ) AS `table_type`, + 'utf8mb4_0900_ai_ci' AS `table_collation`, + + table_catalog AS `TABLE_CATALOG`, + table_schema AS `TABLE_SCHEMA`, + table_name AS `TABLE_NAME`, + table_comment AS `TABLE_COMMENT`, + table_type AS `TABLE_TYPE`, + table_collation AS `TABLE_COLLATION` +FROM system.tables \ No newline at end of file diff --git a/src/Storages/System/attachInformationSchemaTables.cpp b/src/Storages/System/attachInformationSchemaTables.cpp index 074a648d235..d4775bf0d4a 100644 --- a/src/Storages/System/attachInformationSchemaTables.cpp +++ b/src/Storages/System/attachInformationSchemaTables.cpp @@ -12,7 +12,8 @@ INCBIN(resource_schemata_sql, SOURCE_DIR "/src/Storages/System/InformationSchema INCBIN(resource_tables_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/tables.sql"); INCBIN(resource_views_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/views.sql"); INCBIN(resource_columns_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/columns.sql"); - +INCBIN(resource_key_column_usage_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/key_column_usage.sql"); +INCBIN(resource_referential_constraints_sql, SOURCE_DIR "/src/Storages/System/InformationSchema/referential_constraints.sql"); namespace DB { @@ -66,6 +67,8 @@ void attachInformationSchema(ContextMutablePtr context, IDatabase & information_ createInformationSchemaView(context, information_schema_database, "tables", std::string_view(reinterpret_cast(gresource_tables_sqlData), gresource_tables_sqlSize)); createInformationSchemaView(context, information_schema_database, "views", std::string_view(reinterpret_cast(gresource_views_sqlData), gresource_views_sqlSize)); createInformationSchemaView(context, information_schema_database, "columns", std::string_view(reinterpret_cast(gresource_columns_sqlData), gresource_columns_sqlSize)); + createInformationSchemaView(context, information_schema_database, "key_column_usage", std::string_view(reinterpret_cast(gresource_key_column_usage_sqlData), gresource_key_column_usage_sqlSize)); + createInformationSchemaView(context, information_schema_database, "referential_constraints", std::string_view(reinterpret_cast(gresource_referential_constraints_sqlData), gresource_referential_constraints_sqlSize)); } }