# Source path: ClickHouse/tests/queries/0_stateless/helpers/tcp_client.py
# Listing metadata: 314 lines, 9.2 KiB, Python

import socket
import os
import uuid
import struct
# Connection parameters, each overridable through an environment variable.
CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "127.0.0.1")
# The fallback must be a valid TCP port (1-65535); 9000 is the same port that
# the hard-coded initial_address in serializeClientInfo uses.
CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT_TCP", "9000"))
CLICKHOUSE_DATABASE = os.environ.get("CLICKHOUSE_DATABASE", "default")

# Name this client reports in the Hello packet and in the client info.
CLIENT_NAME = "simple native protocol"
def writeVarUInt(x, ba):
    """Append ``x`` to ``ba`` as a variable-length unsigned integer.

    Each output byte carries the next 7 low-order bits of the value; bit 7 is
    set while more bits remain. At most 9 bytes are written, so anything above
    63 bits is silently cut off after the ninth byte.

    Args:
        x: Non-negative integer to encode.
        ba: Mutable byte buffer (``bytearray``) that receives the encoding.
    """
    remaining = x
    emitted = 0
    while emitted < 9:
        low_bits = remaining & 0x7F
        more_follows = remaining > 0x7F
        ba.append((low_bits | 0x80) if more_follows else low_bits)
        emitted += 1

        remaining >>= 7
        if remaining == 0:
            break
def writeStringBinary(s, ba):
    """Append ``s`` to ``ba`` as a length-prefixed UTF-8 string.

    The prefix is the varuint-encoded number of *encoded bytes*, not the
    number of characters: readStringBinary reads exactly that many bytes
    before decoding, so a character count would desynchronise the stream for
    any non-ASCII text.

    Args:
        s: Text to serialize.
        ba: Mutable byte buffer (``bytearray``) that receives the encoding.
    """
    encoded = bytes(s, "utf-8")
    writeVarUInt(len(encoded), ba)
    ba.extend(encoded)
def serializeClientInfo(ba, query_id):
    """Append the client-info section of a Query packet to ``ba``.

    All values are fixed placeholders except the query id. Field order is
    significant and must not be changed.

    Args:
        ba: Mutable byte buffer (``bytearray``) that receives the encoding.
        query_id: Identifier sent as ``initial_query_id``.
    """
    # initial_user, initial_query_id, initial_address
    for text in ("default", query_id, "127.0.0.1:9000"):
        writeStringBinary(text, ba)

    ba.extend(bytes(8))  # initial_query_start_time_microseconds
    ba.append(1)  # TCP

    # os_user, client_hostname, client_name
    for text in ("os_user", "client_hostname", CLIENT_NAME):
        writeStringBinary(text, ba)

    # The same three numbers sendHello sends after the client name;
    # presumably version major, version minor and protocol revision.
    for number in (21, 9, 54449):
        writeVarUInt(number, ba)

    writeStringBinary("", ba)  # quota_key
    writeVarUInt(0, ba)  # distributed_depth
    writeVarUInt(1, ba)  # client_version_patch
    ba.append(0)  # No telemetry
def serializeBlockInfo(ba):
    """Append an all-zero block-info record to ``ba``.

    The bytes produced are ``01 00 02 00 00 00 00 00``.

    NOTE(review): readHeaderInfo parses this structure as field 1,
    is_overflows, field 2, a 4-byte bucket_num and then a terminating 0.
    Because everything after the field-2 marker is zero here, writing the
    varuint 0 before the four bucket_num bytes yields the same byte sequence.

    Args:
        ba: Mutable byte buffer (``bytearray``) that receives the encoding.
    """
    writeVarUInt(1, ba)  # field number 1
    ba.append(0)  # is_overflows
    for marker in (2, 0):  # field number 2, then a zero varuint
        writeVarUInt(marker, ba)
    ba.extend(bytes(4))  # bucket_num
def assertPacket(packet, expected):
    """Assert that a value read from the wire equals the expected one.

    Args:
        packet: Value actually received.
        expected: Value the protocol requires at this position.

    Raises:
        AssertionError: If the two differ; the message shows both values.
    """
    mismatch_template = "Got: {}, expected: {}"
    assert packet == expected, mismatch_template.format(packet, expected)
class Data(object):
    """One column of a result block.

    readDataWithoutProgress builds these with the column name as ``key`` and
    the list of decoded cell values as ``value``.
    """

    def __init__(self, key, value):
        self.key, self.value = key, value
class TCPClient(object):
    """Minimal blocking client for the ClickHouse native TCP protocol.

    Meant to be used as a context manager: entering opens a socket to
    ``CLICKHOUSE_HOST:CLICKHOUSE_PORT`` and performs the Hello exchange;
    leaving closes the socket.

    Packet type codes used by the methods below:
        client -> server: 0 Hello, 1 Query, 2 Data
        server -> client: 0 Hello, 1 Data, 2 Exception, 3 Progress,
                          5 End stream
    """

    def __init__(self, timeout=30):
        """Remember the socket timeout (seconds); no connection is made yet."""
        self.timeout = timeout
        self.socket = None

    def __enter__(self):
        """Connect, exchange Hello packets and return the client itself."""
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.socket.settimeout(self.timeout)
            self.socket.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT))
            self.sendHello()
            self.receiveHello()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so the socket
            # has to be released here to avoid leaking the descriptor.
            self.socket.close()
            self.socket = None
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the socket if one is open; exceptions are not suppressed."""
        if self.socket is not None:
            self.socket.close()
            self.socket = None

    # ------------------------------------------------------------------
    # Low-level readers
    # ------------------------------------------------------------------

    def readStrict(self, size=1):
        """Read exactly ``size`` bytes from the socket.

        Returns:
            A ``bytearray`` of length ``size``.

        Raises:
            ConnectionError: If the peer closes the connection before all
                requested bytes have arrived.
        """
        res = bytearray()
        while size:
            cur = self.socket.recv(size)
            if not cur:
                # recv() returning no data means the peer closed the stream;
                # without this check the loop would never terminate.
                raise ConnectionError(
                    "Socket is closed ({} more bytes expected)".format(size)
                )
            size -= len(cur)
            res.extend(cur)
        return res

    def readUInt(self, size=1):
        """Read a ``size``-byte little-endian unsigned integer."""
        return int.from_bytes(self.readStrict(size), "little")

    def readUInt8(self):
        """Read a 1-byte unsigned integer."""
        return self.readUInt()

    def readUInt16(self):
        """Read a 2-byte little-endian unsigned integer."""
        return self.readUInt(2)

    def readUInt32(self):
        """Read a 4-byte little-endian unsigned integer."""
        return self.readUInt(4)

    def readUInt64(self):
        """Read an 8-byte little-endian unsigned integer."""
        return self.readUInt(8)

    # The float readers return the 1-tuple produced by struct.unpack (kept for
    # backward compatibility with existing callers). Byte order is explicitly
    # little-endian so they agree with readUInt on every host.

    def readFloat16(self):
        """Read a 2-byte float; returns a 1-tuple ``(value,)``."""
        return struct.unpack("<e", self.readStrict(2))

    def readFloat32(self):
        """Read a 4-byte float; returns a 1-tuple ``(value,)``."""
        return struct.unpack("<f", self.readStrict(4))

    def readFloat64(self):
        """Read an 8-byte float; returns a 1-tuple ``(value,)``."""
        return struct.unpack("<d", self.readStrict(8))

    def readVarUInt(self):
        """Read a variable-length unsigned integer (7 bits per byte).

        Stops at the first byte whose high bit is clear, or after 9 bytes.
        """
        x = 0
        for i in range(9):
            byte = self.readStrict()[0]
            x |= (byte & 0x7F) << (7 * i)

            if not byte & 0x80:
                return x

        return x

    def readStringBinary(self):
        """Read a varuint byte count followed by that many UTF-8 bytes."""
        size = self.readVarUInt()
        s = self.readStrict(size)
        return s.decode("utf-8")

    # ------------------------------------------------------------------
    # Senders
    # ------------------------------------------------------------------

    def send(self, byte_array):
        """Send the whole buffer to the server."""
        self.socket.sendall(byte_array)

    def sendHello(self):
        """Send the client Hello packet (name, version numbers, credentials)."""
        ba = bytearray()
        writeVarUInt(0, ba)  # Hello
        writeStringBinary(CLIENT_NAME, ba)
        writeVarUInt(21, ba)
        writeVarUInt(9, ba)
        writeVarUInt(54449, ba)
        writeStringBinary(CLICKHOUSE_DATABASE, ba)  # database
        writeStringBinary("default", ba)  # user
        writeStringBinary("", ba)  # pwd
        self.send(ba)

    def receiveHello(self):
        """Consume the server Hello packet; its fields are read and dropped.

        Raises:
            AssertionError: If the first packet is not a Hello (type 0).
        """
        assertPacket(self.readVarUInt(), 0)  # Hello
        self.readStringBinary()  # server_name
        self.readVarUInt()  # server_version_major
        self.readVarUInt()  # server_version_minor
        self.readVarUInt()  # server_revision
        self.readStringBinary()  # server_timezone
        self.readStringBinary()  # server_display_name
        self.readVarUInt()  # server_version_patch

    def sendQuery(self, query, settings=None):
        """Send a Query packet with a freshly generated query id.

        Args:
            query: SQL text to execute.
            settings: Optional mapping of setting name to value. Each value is
                sent as ``str(value)`` and every setting is flagged important.
        """
        if settings is None:
            settings = {}  # No settings

        ba = bytearray()
        query_id = uuid.uuid4().hex
        writeVarUInt(1, ba)  # query
        writeStringBinary(query_id, ba)

        ba.append(1)  # INITIAL_QUERY

        # client info
        serializeClientInfo(ba, query_id)

        # Settings
        for key, value in settings.items():
            writeStringBinary(key, ba)
            writeVarUInt(1, ba)  # is_important
            writeStringBinary(str(value), ba)
        writeStringBinary("", ba)  # End of settings

        writeStringBinary("", ba)  # No interserver secret
        writeVarUInt(2, ba)  # Stage - Complete
        ba.append(0)  # No compression
        writeStringBinary(query, ba)  # query, finally
        self.send(ba)

    def sendEmptyBlock(self):
        """Send a Data packet that contains a block with no rows or columns."""
        ba = bytearray()
        writeVarUInt(2, ba)  # Data
        writeStringBinary("", ba)
        serializeBlockInfo(ba)
        writeVarUInt(0, ba)  # rows
        writeVarUInt(0, ba)  # columns
        self.send(ba)

    # ------------------------------------------------------------------
    # Packet readers
    # ------------------------------------------------------------------

    def readException(self):
        """Read the body of an Exception packet and format it as text.

        Returns:
            ``"code <code>: <text>"`` with every ``"DB::Exception:"`` marker
            removed from the text.

        Raises:
            AssertionError: If the exception reports a nested exception,
                which this client does not parse.
        """
        code = self.readUInt32()
        self.readStringBinary()  # name
        text = self.readStringBinary()
        self.readStringBinary()  # trace
        assertPacket(self.readUInt8(), 0)  # has_nested
        return "code {}: {}".format(code, text.replace("DB::Exception:", ""))

    def readPacketType(self):
        """Read the next packet type code.

        Raises:
            RuntimeError: If the packet is an Exception (type 2); the message
                is the text produced by readException.
        """
        packet_type = self.readVarUInt()
        if packet_type == 2:  # Exception
            raise RuntimeError(self.readException())

        return packet_type

    def readResponse(self):
        """Read one packet type and accept Data, Progress or End stream.

        Only the type code is consumed; the packet body, if any, is left
        unread. Always returns ``None``.

        Raises:
            RuntimeError: On an Exception packet or any other packet type.
        """
        packet_type = self.readPacketType()
        if packet_type in (1, 3, 5):  # Data, Progress, End stream
            return None

        raise RuntimeError("Unexpected packet: {}".format(packet_type))

    def readProgressData(self):
        """Read the five varuint counters of a Progress packet body.

        Returns:
            ``(read_rows, read_bytes, total_rows_to_read, written_rows,
            written_bytes)``.
        """
        read_rows = self.readVarUInt()
        read_bytes = self.readVarUInt()
        total_rows_to_read = self.readVarUInt()
        written_rows = self.readVarUInt()
        written_bytes = self.readVarUInt()
        return read_rows, read_bytes, total_rows_to_read, written_rows, written_bytes

    def readProgress(self):
        """Read a Progress packet.

        Returns:
            The tuple from readProgressData, or ``None`` on End stream.

        Raises:
            AssertionError: If the packet is neither Progress nor End stream.
        """
        packet_type = self.readPacketType()
        if packet_type == 5:  # End stream
            return None
        assertPacket(packet_type, 3)  # Progress
        return self.readProgressData()

    def readHeaderInfo(self):
        """Read the preamble of a Data packet up to the column definitions.

        Returns:
            ``(columns, rows)`` -- the column count followed by the row count,
            in the order they appear on the wire.

        Raises:
            AssertionError: If the block info differs from the expected
                defaults (no overflows, bucket_num of 4294967295).
        """
        self.readStringBinary()  # external table name
        # BlockInfo
        assertPacket(self.readVarUInt(), 1)  # field number 1
        assertPacket(self.readUInt8(), 0)  # is_overflows
        assertPacket(self.readVarUInt(), 2)  # field number 2
        assertPacket(self.readUInt32(), 4294967295)  # bucket_num
        assertPacket(self.readVarUInt(), 0)  # 0
        columns = self.readVarUInt()  # number of columns
        rows = self.readVarUInt()  # number of rows
        return columns, rows

    def readHeader(self):
        """Read a Data packet that carries only column names and types.

        Prints the row/column counts and one line per column. No cell values
        are read.
        """
        packet_type = self.readPacketType()
        assertPacket(packet_type, 1)  # Data

        columns, rows = self.readHeaderInfo()
        print("Rows {} Columns {}".format(rows, columns))
        for _ in range(columns):
            col_name = self.readStringBinary()
            type_name = self.readStringBinary()
            print("Column {} type {}".format(col_name, type_name))

    def readRow(self, row_type, rows):
        """Read ``rows`` consecutive values of one column.

        Args:
            row_type: Column type name; one of UInt8/16/32/64 or
                Float16/32/64.
            rows: Number of values to read.

        Returns:
            A list with ``rows`` decoded values (float entries are 1-tuples,
            see the float readers).

        Raises:
            RuntimeError: If ``row_type`` is not supported.
        """
        supported_row_types = {
            "UInt8": self.readUInt8,
            "UInt16": self.readUInt16,
            "UInt32": self.readUInt32,
            "UInt64": self.readUInt64,
            "Float16": self.readFloat16,
            "Float32": self.readFloat32,
            "Float64": self.readFloat64,
        }
        read_type = supported_row_types.get(row_type)
        if read_type is None:
            raise RuntimeError(
                "Current python version of tcp client doesn't support the following type of row: {}".format(
                    row_type
                )
            )
        return [read_type() for _ in range(rows)]

    def readDataWithoutProgress(self, need_print_info=True):
        """Skip Progress packets and read the next Data packet.

        Args:
            need_print_info: When true, print the row/column counts and each
                column's name and type.

        Returns:
            A list of ``Data`` objects, one per column, or ``None`` if the
            stream ended before a Data packet arrived.

        Raises:
            AssertionError: If a packet other than Progress, End stream or
                Data is received.
            RuntimeError: On an Exception packet or an unsupported column type.
        """
        packet_type = self.readPacketType()
        while packet_type == 3:  # Progress
            self.readProgressData()
            packet_type = self.readPacketType()

        if packet_type == 5:  # End stream
            return None

        assertPacket(packet_type, 1)  # Data
        columns, rows = self.readHeaderInfo()
        data = []
        if need_print_info:
            print("Rows {} Columns {}".format(rows, columns))

        for _ in range(columns):
            col_name = self.readStringBinary()
            type_name = self.readStringBinary()
            if need_print_info:
                print("Column {} type {}".format(col_name, type_name))

            data.append(Data(col_name, self.readRow(type_name, rows)))

        return data