From 9c42f7f113d670bc891145b5bb827f983ea73066 Mon Sep 17 00:00:00 2001 From: Azat Khuzhin Date: Thu, 11 May 2023 12:46:16 +0200 Subject: [PATCH] Fix settings aliases in native protocol The initial test (tests/queries/0_stateless/02539_settings_alias.sh) works only because of the clickhouse-client, while in native protocol aliases does not work. Signed-off-by: Azat Khuzhin --- src/Core/BaseSettings.h | 6 +- .../02750_settings_alias_tcp_protocol.python | 218 ++++++++++++++++++ ...2750_settings_alias_tcp_protocol.reference | 1 + .../02750_settings_alias_tcp_protocol.sh | 9 + 4 files changed, 232 insertions(+), 2 deletions(-) create mode 100644 tests/queries/0_stateless/02750_settings_alias_tcp_protocol.python create mode 100644 tests/queries/0_stateless/02750_settings_alias_tcp_protocol.reference create mode 100755 tests/queries/0_stateless/02750_settings_alias_tcp_protocol.sh diff --git a/src/Core/BaseSettings.h b/src/Core/BaseSettings.h index 521422b780e..a14cec9cc7d 100644 --- a/src/Core/BaseSettings.h +++ b/src/Core/BaseSettings.h @@ -501,9 +501,11 @@ void BaseSettings::read(ReadBuffer & in, SettingsWriteFormat format) const auto & accessor = Traits::Accessor::instance(); while (true) { - String name = BaseSettingsHelpers::readString(in); - if (name.empty() /* empty string is a marker of the end of settings */) + String read_name = BaseSettingsHelpers::readString(in); + if (read_name.empty() /* empty string is a marker of the end of settings */) break; + + std::string_view name = TTraits::resolveName(read_name); size_t index = accessor.find(name); using Flags = BaseSettingsHelpers::Flags; diff --git a/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.python b/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.python new file mode 100644 index 00000000000..768fb2144e3 --- /dev/null +++ b/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.python @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 + +import socket +import os +import uuid +import json + +CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "127.0.0.1") +CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT_TCP", "900000")) +CLICKHOUSE_DATABASE = os.environ.get("CLICKHOUSE_DATABASE", "default") + + +def writeVarUInt(x, ba): + for _ in range(0, 9): + byte = x & 0x7F + if x > 0x7F: + byte |= 0x80 + + ba.append(byte) + + x >>= 7 + if x == 0: + return + + +def writeStringBinary(s, ba): + b = bytes(s, "utf-8") + writeVarUInt(len(s), ba) + ba.extend(b) + + +def readStrict(s, size=1): + res = bytearray() + while size: + cur = s.recv(size) + # if not res: + # raise "Socket is closed" + size -= len(cur) + res.extend(cur) + + return res + + +def readUInt(s, size=1): + res = readStrict(s, size) + val = 0 + for i in range(len(res)): + val += res[i] << (i * 8) + return val + + +def readUInt8(s): + return readUInt(s) + + +def readUInt16(s): + return readUInt(s, 2) + + +def readUInt32(s): + return readUInt(s, 4) + + +def readUInt64(s): + return readUInt(s, 8) + + +def readVarUInt(s): + x = 0 + for i in range(9): + byte = readStrict(s)[0] + x |= (byte & 0x7F) << (7 * i) + + if not byte & 0x80: + return x + + return x + + +def readStringBinary(s): + size = readVarUInt(s) + s = readStrict(s, size) + return s.decode("utf-8") + + +def sendHello(s): + ba = bytearray() + writeVarUInt(0, ba) # Hello + writeStringBinary("simple native protocol", ba) + writeVarUInt(21, ba) + writeVarUInt(9, ba) + writeVarUInt(54449, ba) + writeStringBinary(CLICKHOUSE_DATABASE, ba) # database + writeStringBinary("default", ba) # user + writeStringBinary("", ba) # pwd + s.sendall(ba) + + +def receiveHello(s): + p_type = readVarUInt(s) + assert p_type == 0 # Hello + _server_name = readStringBinary(s) + _server_version_major = readVarUInt(s) + _server_version_minor = readVarUInt(s) + _server_revision = readVarUInt(s) + _server_timezone = readStringBinary(s) + _server_display_name = readStringBinary(s) + _server_version_patch = readVarUInt(s) + + +def serializeClientInfo(ba, query_id): + writeStringBinary("default", ba) # initial_user + writeStringBinary(query_id, ba) # initial_query_id + writeStringBinary("127.0.0.1:9000", ba) # initial_address + ba.extend([0] * 8) # initial_query_start_time_microseconds + ba.append(1) # TCP + writeStringBinary("os_user", ba) # os_user + writeStringBinary("client_hostname", ba) # client_hostname + writeStringBinary("client_name", ba) # client_name + writeVarUInt(21, ba) + writeVarUInt(9, ba) + writeVarUInt(54449, ba) + writeStringBinary("", ba) # quota_key + writeVarUInt(0, ba) # distributed_depth + writeVarUInt(1, ba) # client_version_patch + ba.append(0) # No telemetry + + +def sendQuery(s, query, settings): + ba = bytearray() + query_id = uuid.uuid4().hex + writeVarUInt(1, ba) # query + writeStringBinary(query_id, ba) + + ba.append(1) # INITIAL_QUERY + + # client info + serializeClientInfo(ba, query_id) + + # Settings + for key, value in settings.items(): + writeStringBinary(key, ba) + writeVarUInt(1, ba) # is_important + writeStringBinary(str(value), ba) + writeStringBinary("", ba) # End of settings + + writeStringBinary("", ba) # No interserver secret + writeVarUInt(2, ba) # Stage - Complete + ba.append(0) # No compression + writeStringBinary(query, ba) # query, finally + s.sendall(ba) + + +def serializeBlockInfo(ba): + writeVarUInt(1, ba) # 1 + ba.append(0) # is_overflows + writeVarUInt(2, ba) # 2 + writeVarUInt(0, ba) # 0 + ba.extend([0] * 4) # bucket_num + + +def sendEmptyBlock(s): + ba = bytearray() + writeVarUInt(2, ba) # Data + writeStringBinary("", ba) + serializeBlockInfo(ba) + writeVarUInt(0, ba) # rows + writeVarUInt(0, ba) # columns + s.sendall(ba) + + +def assertPacket(packet, expected): + assert packet == expected, "Got: {}, expected: {}".format(packet, expected) + + +def readResponse(s): + packet_type = readVarUInt(s) + if packet_type == 2: # Exception + raise RuntimeError(readException(s)) + + if packet_type == 1: # Data + return None + if packet_type == 3: # Progress + return None + if packet_type == 5: # End stream + return None + + raise RuntimeError("Unexpected packet: {}".format(packet_type)) + + +def readException(s): + code = readUInt32(s) + _name = readStringBinary(s) + text = readStringBinary(s) + readStringBinary(s) # trace + assertPacket(readUInt8(s), 0) # has_nested + return "code {}: {}".format(code, text.replace("DB::Exception:", "")) + + +def main(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.settimeout(30) + s.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT)) + sendHello(s) + receiveHello(s) + sendQuery(s, "select 1", {"replication_alter_partitions_sync": 1}) + # external tables + sendEmptyBlock(s) + + while readResponse(s) is not None: + pass + + s.close() + print("OK") + + +if __name__ == "__main__": + main() diff --git a/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.reference b/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.reference new file mode 100644 index 00000000000..d86bac9de59 --- /dev/null +++ b/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.reference @@ -0,0 +1 @@ +OK diff --git a/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.sh b/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.sh new file mode 100755 index 00000000000..35d685c1580 --- /dev/null +++ b/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +# NOTE: this sh wrapper is required because of shell_config + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +python3 "$CURDIR"/02750_settings_alias_tcp_protocol.python