#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "MySQLHandler.h" #include #include #include namespace DB { using namespace MySQLProtocol; using Poco::Net::SecureStreamSocket; using Poco::Net::SSLManager; namespace ErrorCodes { extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES; extern const int OPENSSL_ERROR; } MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, RSA & public_key, RSA & private_key, bool ssl_enabled, size_t connection_id) : Poco::Net::TCPServerConnection(socket_) , server(server_) , log(&Poco::Logger::get("MySQLHandler")) , connection_context(server.context()) , connection_id(connection_id) , public_key(public_key) , private_key(private_key) { server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF; if (ssl_enabled) server_capability_flags |= CLIENT_SSL; } void MySQLHandler::run() { connection_context = server.context(); connection_context.makeSessionContext(); connection_context.setDefaultFormat("MySQLWire"); in = std::make_shared(socket()); out = std::make_shared(socket()); packet_sender = std::make_shared(*in, *out, connection_context.sequence_id); try { String scramble = generateScramble(); /** Native authentication sent 20 bytes + '\0' character = 21 bytes. * This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin. * https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994 */ Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, scramble + '\0'); packet_sender->sendPacket(handshake, true); LOG_TRACE(log, "Sent handshake"); HandshakeResponse handshake_response = finishHandshake(); connection_context.client_capabilities = handshake_response.capability_flags; if (handshake_response.max_packet_size) connection_context.max_packet_size = handshake_response.max_packet_size; if (!connection_context.max_packet_size) connection_context.max_packet_size = MAX_PACKET_LENGTH; LOG_DEBUG(log, "Capabilities: " << handshake_response.capability_flags << "\nmax_packet_size: " << handshake_response.max_packet_size << "\ncharacter_set: " << handshake_response.character_set << "\nuser: " << handshake_response.username << "\nauth_response length: " << handshake_response.auth_response.length() << "\nauth_response: " << handshake_response.auth_response << "\ndatabase: " << handshake_response.database << "\nauth_plugin_name: " << handshake_response.auth_plugin_name); client_capability_flags = handshake_response.capability_flags; if (!(client_capability_flags & CLIENT_PROTOCOL_41)) throw Exception("Required capability: CLIENT_PROTOCOL_41.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES); if (!(client_capability_flags & CLIENT_PLUGIN_AUTH)) throw Exception("Required capability: CLIENT_PLUGIN_AUTH.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES); authenticate(handshake_response, scramble); OK_Packet ok_packet(0, handshake_response.capability_flags, 0, 0, 0); packet_sender->sendPacket(ok_packet, true); while (true) { packet_sender->resetSequenceId(); String payload = packet_sender->receivePacketPayload(); int command = payload[0]; LOG_DEBUG(log, "Received command: " << std::to_string(command) << ". Connection id: " << connection_id << "."); try { switch (command) { case COM_QUIT: return; case COM_INIT_DB: comInitDB(payload); break; case COM_QUERY: comQuery(payload); break; case COM_FIELD_LIST: comFieldList(payload); break; case COM_PING: comPing(); break; default: throw Exception(Poco::format("Command %d is not implemented.", command), ErrorCodes::NOT_IMPLEMENTED); } } catch (const NetException & exc) { log->log(exc); throw; } catch (const Exception & exc) { log->log(exc); packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true); } } } catch (Poco::Exception & exc) { log->log(exc); } } /** Reads 3 bytes, finds out whether it is SSLRequest or HandshakeResponse packet, starts secure connection, if it is SSLRequest. * Reading is performed from socket instead of ReadBuffer to prevent reading part of SSL handshake. * If we read it from socket, it will be impossible to start SSL connection using Poco. Size of SSLRequest packet payload is 32 bytes, thus we can read at most 36 bytes. */ MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake() { HandshakeResponse packet; size_t packet_size = PACKET_HEADER_SIZE + SSL_REQUEST_PAYLOAD_SIZE; /// Buffer for SSLRequest or part of HandshakeResponse. char buf[packet_size]; size_t pos = 0; /// Reads at least count and at most packet_size bytes. auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void { while (pos < count) { int ret = socket().receiveBytes(buf + pos, packet_size - pos); if (ret == 0) { throw Exception("Cannot read all data. Bytes read: " + std::to_string(pos) + ". Bytes expected: 3.", ErrorCodes::CANNOT_READ_ALL_DATA); } pos += ret; } }; read_bytes(3); /// We can find out whether it is SSLRequest of HandshakeResponse by first 3 bytes. size_t payload_size = unalignedLoad(buf) & 0xFFFFFFu; LOG_TRACE(log, "payload size: " << payload_size); if (payload_size == SSL_REQUEST_PAYLOAD_SIZE) { read_bytes(packet_size); /// Reading rest SSLRequest. SSLRequest ssl_request; ssl_request.readPayload(String(buf + PACKET_HEADER_SIZE, pos - PACKET_HEADER_SIZE)); connection_context.client_capabilities = ssl_request.capability_flags; connection_context.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH; secure_connection = true; ss = std::make_shared(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext())); in = std::make_shared(*ss); out = std::make_shared(*ss); connection_context.sequence_id = 2; packet_sender = std::make_shared(*in, *out, connection_context.sequence_id); packet_sender->max_packet_size = connection_context.max_packet_size; packet_sender->receivePacket(packet); /// Reading HandshakeResponse from secure socket. } else { /// Reading rest of HandshakeResponse. packet_size = PACKET_HEADER_SIZE + payload_size; WriteBufferFromOwnString buf_for_handshake_response; buf_for_handshake_response.write(buf, pos); copyData(*packet_sender->in, buf_for_handshake_response, packet_size - pos); packet.readPayload(buf_for_handshake_response.str().substr(PACKET_HEADER_SIZE)); packet_sender->sequence_id++; } return packet; } String MySQLHandler::generateScramble() { String scramble(MySQLProtocol::SCRAMBLE_LENGTH, 0); Poco::RandomInputStream generator; for (size_t i = 0; i < scramble.size(); i++) { generator >> scramble[i]; } return scramble; } void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, const String & scramble) { String auth_response; AuthSwitchResponse response; if (handshake_response.auth_plugin_name != Authentication::SHA256) { packet_sender->sendPacket(AuthSwitchRequest(Authentication::SHA256, scramble + '\0'), true); if (in->eof()) throw Exception( "Client doesn't support authentication method " + String(Authentication::SHA256) + " used by ClickHouse", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES); packet_sender->receivePacket(response); auth_response = response.value; LOG_TRACE(log, "Authentication method mismatch."); } else { auth_response = handshake_response.auth_response; LOG_TRACE(log, "Authentication method match."); } if (auth_response == "\1") { LOG_TRACE(log, "Client requests public key."); BIO * mem = BIO_new(BIO_s_mem()); SCOPE_EXIT(BIO_free(mem)); if (PEM_write_bio_RSA_PUBKEY(mem, &public_key) != 1) { throw Exception("Failed to write public key to memory. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR); } char * pem_buf = nullptr; long pem_size = BIO_get_mem_data(mem, &pem_buf); String pem(pem_buf, pem_size); LOG_TRACE(log, "Key: " << pem); AuthMoreData data(pem); packet_sender->sendPacket(data, true); packet_sender->receivePacket(response); auth_response = response.value; } else { LOG_TRACE(log, "Client didn't request public key."); } String password; /** Decrypt password, if it's not empty. * The original intention was that the password is a string[NUL] but this never got enforced properly so now we have to accept that * an empty packet is a blank password, thus the check for auth_response.empty() has to be made too. * https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L4017 */ if (!secure_connection && !auth_response.empty() && auth_response != String("\0", 1)) { LOG_TRACE(log, "Received nonempty password"); auto ciphertext = reinterpret_cast(auth_response.data()); unsigned char plaintext[RSA_size(&private_key)]; int plaintext_size = RSA_private_decrypt(auth_response.size(), ciphertext, plaintext, &private_key, RSA_PKCS1_OAEP_PADDING); if (plaintext_size == -1) { throw Exception("Failed to decrypt auth data. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR); } password.resize(plaintext_size); for (int i = 0; i < plaintext_size; i++) { password[i] = plaintext[i] ^ static_cast(scramble[i % scramble.size()]); } } else if (secure_connection) { password = auth_response; } else { LOG_TRACE(log, "Received empty password"); } if (!password.empty()) { password.pop_back(); /// terminating null byte } try { connection_context.setUser(handshake_response.username, password, socket().address(), ""); if (!handshake_response.database.empty()) connection_context.setCurrentDatabase(handshake_response.database); connection_context.setCurrentQueryId(""); LOG_ERROR(log, "Authentication for user " << handshake_response.username << " succeeded."); } catch (const Exception & exc) { LOG_ERROR(log, "Authentication for user " << handshake_response.username << " failed."); packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true); throw; } } void MySQLHandler::comInitDB(const String & payload) { String database = payload.substr(1); LOG_DEBUG(log, "Setting current database to " << database); connection_context.setCurrentDatabase(database); packet_sender->sendPacket(OK_Packet(0, client_capability_flags, 0, 0, 1), true); } void MySQLHandler::comFieldList(const String & payload) { ComFieldList packet; packet.readPayload(payload); String database = connection_context.getCurrentDatabase(); StoragePtr tablePtr = connection_context.getTable(database, packet.table); for (const NameAndTypePair & column: tablePtr->getColumns().getAll()) { ColumnDefinition column_definition( database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0 ); packet_sender->sendPacket(column_definition); } packet_sender->sendPacket(OK_Packet(0xfe, client_capability_flags, 0, 0, 0), true); } void MySQLHandler::comPing() { packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true); } void MySQLHandler::comQuery(const String & payload) { bool with_output = false; std::function set_content_type = [&with_output](const String &) -> void { with_output = true; }; String query = payload.substr(1); // Translate query from MySQL to ClickHouse. // This is a temporary workaround until ClickHouse supports the syntax "@@var_name". if (query == "select @@version_comment limit 1") // MariaDB client starts session with that query query = "select ''"; ReadBufferFromString buf(query); executeQuery(buf, *out, true, connection_context, set_content_type, nullptr); if (!with_output) packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true); } }