ClickHouse/tests/queries/0_stateless/02458_insert_select_progress_tcp.python

265 lines
7.0 KiB
Python

#!/usr/bin/env python3
import socket
import os
import uuid
import json
# Connection parameters; each one can be overridden through the environment.
CLICKHOUSE_HOST = os.environ.get('CLICKHOUSE_HOST', '127.0.0.1')
# The fallback must be a valid TCP port (<= 65535); the former fallback of
# 900000 could never be connected to.
CLICKHOUSE_PORT = int(os.environ.get('CLICKHOUSE_PORT_TCP', '9000'))
CLICKHOUSE_DATABASE = os.environ.get('CLICKHOUSE_DATABASE', 'default')
def writeVarUInt(x, ba):
    """Append *x* to *ba* as a variable-length integer.

    Each output byte carries 7 bits of the value, least significant group
    first; bit 0x80 is set while more significant bits remain. At most 9
    bytes are written, so anything above 63 bits is silently dropped.
    """
    remaining = x
    emitted = 0
    while emitted < 9:
        low7 = remaining & 0x7F
        more_follows = remaining > 0x7F
        ba.append((low7 | 0x80) if more_follows else low7)
        remaining >>= 7
        emitted += 1
        if remaining == 0:
            break
def writeStringBinary(s, ba):
    """Append string *s* to *ba* as a length-prefixed UTF-8 byte sequence.

    The prefix is written with writeVarUInt and counts *encoded bytes*, not
    characters: the reader consumes exactly that many bytes, so using the
    character count would desynchronise the stream for non-ASCII text.
    """
    encoded = s.encode('utf-8')
    writeVarUInt(len(encoded), ba)
    ba.extend(encoded)
def readStrict(s, size=1):
    """Read exactly *size* bytes from socket *s* and return them as a bytearray.

    recv() may return fewer bytes than requested, so it is called repeatedly
    until the requested amount has been collected.

    Raises:
        ConnectionError: if the peer closes the connection before *size*
            bytes arrive. recv() then returns an empty result, and without
            this check the loop would never terminate.
    """
    res = bytearray()
    while size:
        cur = s.recv(size)
        if not cur:
            raise ConnectionError("Socket is closed")
        size -= len(cur)
        res.extend(cur)
    return res
def readUInt(s, size=1):
    """Read *size* bytes from socket *s* and decode them as a little-endian unsigned integer."""
    raw = readStrict(s, size)
    return int.from_bytes(raw, 'little')
def readUInt8(s):
    """Read a 1-byte unsigned integer from socket *s*."""
    return readUInt(s, 1)
def readUInt16(s):
    """Read a 2-byte little-endian unsigned integer from socket *s*."""
    width = 2
    return readUInt(s, width)
def readUInt32(s):
    """Read a 4-byte little-endian unsigned integer from socket *s*."""
    width = 4
    return readUInt(s, width)
def readUInt64(s):
    """Read an 8-byte little-endian unsigned integer from socket *s*."""
    width = 8
    return readUInt(s, width)
def readVarUInt(s):
    """Read a variable-length unsigned integer from socket *s*.

    Bytes are consumed one at a time; the low 7 bits of each byte are
    placed at increasing bit offsets and reading stops at the first byte
    without the 0x80 continuation bit, or after 9 bytes.
    """
    value = 0
    for shift in range(0, 63, 7):  # bit offsets 0, 7, ..., 56: nine bytes at most
        octet = readStrict(s)[0]
        value |= (octet & 0x7F) << shift
        if (octet & 0x80) == 0:
            break
    return value
def readStringBinary(s):
    """Read a length-prefixed string from socket *s* and return it decoded as UTF-8.

    The length is a varint giving the number of payload bytes that follow.
    """
    length = readVarUInt(s)
    payload = readStrict(s, length)
    return payload.decode('utf-8')
def sendHello(s):
    """Build the client Hello packet and send it over socket *s* in one call."""
    packet = bytearray()
    writeVarUInt(0, packet)  # packet type: Hello
    writeStringBinary('simple native protocol', packet)
    # Three numbers follow the client name; presumably version major, minor
    # and protocol revision -- confirm against the protocol description.
    for number in (21, 9, 54449):
        writeVarUInt(number, packet)
    # database, user, password (empty)
    for text in (CLICKHOUSE_DATABASE, 'default', ''):
        writeStringBinary(text, packet)
    s.sendall(packet)
def receiveHello(s):
    """Consume the server's Hello packet from socket *s*.

    The packet type must be 0 (Hello). The remaining fields are read only
    to advance the stream and are discarded.
    """
    p_type = readVarUInt(s)
    assert (p_type == 0)  # Hello
    field_readers = (
        readStringBinary,  # server name
        readVarUInt,       # version major
        readVarUInt,       # version minor
        readVarUInt,       # revision
        readStringBinary,  # timezone
        readStringBinary,  # display name
        readVarUInt,       # version patch
    )
    for read_field in field_readers:
        read_field(s)
def serializeClientInfo(ba, query_id):
    """Append the client-info section of a Query packet to *ba*.

    *query_id* is written as the initial query id; every other field is a
    fixed placeholder value.
    """
    # initial_user, initial_query_id, initial_address
    for text in ('default', query_id, '127.0.0.1:9000'):
        writeStringBinary(text, ba)
    ba.extend(bytes(8))  # initial_query_start_time_microseconds
    ba.append(1)  # TCP
    # os_user, client_hostname, client_name
    for text in ('os_user', 'client_hostname', 'client_name'):
        writeStringBinary(text, ba)
    # Same three numbers that sendHello transmits after the client name.
    for number in (21, 9, 54449):
        writeVarUInt(number, ba)
    writeStringBinary('', ba)  # quota_key
    writeVarUInt(0, ba)  # distributed_depth
    writeVarUInt(1, ba)  # client_version_patch
    ba.append(0)  # No telemetry
def sendQuery(s, query):
    """Send *query* over socket *s* as a Query packet with a freshly generated query id."""
    query_id = uuid.uuid4().hex
    packet = bytearray()
    writeVarUInt(1, packet)  # packet type: Query
    writeStringBinary(query_id, packet)
    packet.append(1)  # INITIAL_QUERY
    serializeClientInfo(packet, query_id)
    # Two empty strings: no settings, then no interserver secret.
    for _ in range(2):
        writeStringBinary('', packet)
    writeVarUInt(2, packet)  # Stage - Complete
    packet.append(0)  # No compression
    writeStringBinary(query, packet)  # the query text itself
    s.sendall(packet)
def serializeBlockInfo(ba):
    """Append a default block-info section to *ba*.

    Layout as written: varint 1, one zero byte (is_overflows), varint 2,
    varint 0, then four zero bytes (bucket_num).
    """
    writeVarUInt(1, ba)  # 1
    ba.append(0)  # is_overflows
    for marker in (2, 0):
        writeVarUInt(marker, ba)
    ba.extend(bytes(4))  # bucket_num
def sendEmptyBlock(s):
    """Send a Data packet over socket *s* that carries a block with no rows and no columns."""
    packet = bytearray()
    writeVarUInt(2, packet)  # packet type: Data
    writeStringBinary('', packet)  # empty string preceding the block info
    serializeBlockInfo(packet)
    # rows, then columns
    for count in (0, 0):
        writeVarUInt(count, packet)
    s.sendall(packet)
def assertPacket(packet, expected):
    """Raise AssertionError unless *packet* equals *expected*.

    The exception is raised explicitly rather than through an ``assert``
    statement so that the check stays active under ``python -O``. The
    message names both the received and the expected value.
    """
    if packet == expected:
        return
    raise AssertionError(
        "unexpected packet {!r}, expected {!r}".format(packet, expected))
class Progress:
    """Counters carried by Progress packets, with support for accumulation.

    ``p += q`` adds q's counters into p in place; ``p + q`` returns a new
    object and leaves both operands untouched. An instance is truthy when
    at least one counter is greater than zero.
    """

    def __init__(self):
        # NOTE: the counters are assigned in the ctor so that __dict__ holds
        # them (in this order) for __str__.
        self.read_rows = 0
        self.read_bytes = 0
        self.total_rows_to_read = 0
        self.written_rows = 0
        self.written_bytes = 0

    def __str__(self):
        """Return the counters as a JSON object string."""
        return json.dumps(self.__dict__)

    def __iadd__(self, b):
        """Add the counters of *b* into this object and return it."""
        self.read_rows += b.read_rows
        self.read_bytes += b.read_bytes
        self.total_rows_to_read += b.total_rows_to_read
        self.written_rows += b.written_rows
        self.written_bytes += b.written_bytes
        return self

    def __add__(self, b):
        """Return a new Progress holding the sum of this object and *b*.

        Neither operand is modified; in-place accumulation is provided by
        ``+=``.
        """
        total = Progress()
        total += self
        total += b
        return total

    def readPacket(self, s):
        """Read five varints from socket *s* and add them to the counters.

        The values are consumed in the order read_rows, read_bytes,
        total_rows_to_read, written_rows, written_bytes.
        """
        self.read_rows += readVarUInt(s)
        self.read_bytes += readVarUInt(s)
        self.total_rows_to_read += readVarUInt(s)
        self.written_rows += readVarUInt(s)
        self.written_bytes += readVarUInt(s)

    def __bool__(self):
        """Return True when at least one counter is greater than zero."""
        return (
            self.read_rows > 0 or
            self.read_bytes > 0 or
            self.total_rows_to_read > 0 or
            self.written_rows > 0 or
            self.written_bytes > 0)
def readProgress(s):
    """Read the next packet from socket *s*, which must relate to query progress.

    Returns:
        A Progress filled from the packet, or None when the packet is the
        end-of-stream marker (type 5).

    Raises:
        RuntimeError: when the packet is an exception (type 2); the message
            is the text produced by readException.
        AssertionError: for any other packet type than Progress (type 3).
    """
    packet_type = readVarUInt(s)
    if packet_type == 5:  # End stream
        return None
    if packet_type == 2:  # Exception
        raise RuntimeError(readException(s))
    assertPacket(packet_type, 3)  # Progress

    result = Progress()
    result.readPacket(s)
    return result
def readException(s):
    """Read the body of an exception packet from socket *s* and return a one-line summary.

    The code and message text are kept; the name and trace strings are read
    only to advance the stream. The trailing has_nested byte must be 0.
    """
    code = readUInt32(s)
    readStringBinary(s)  # name
    text = readStringBinary(s)
    readStringBinary(s)  # trace
    assertPacket(readUInt8(s), 0)  # has_nested
    message = text.replace('DB::Exception:', '')
    return "code {}: {}".format(code, message)
def main():
    """Run an INSERT ... SELECT over the native protocol and check its progress packets.

    Prints the accumulated progress as JSON and asserts that 3 or 4
    non-empty Progress packets were received before the end of the stream.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(30)
        s.connect((CLICKHOUSE_HOST, CLICKHOUSE_PORT))
        sendHello(s)
        receiveHello(s)

        # With a 1 second sleep and an interactive_delay of 1000ms we should
        # definitely get a non-zero progress packet.
        # NOTE: interactive_delay=0 cannot be used since in this case
        # CompletedPipelineExecutor will not call the cancelled callback.
        sendQuery(s, "insert into function null('_ Int') select sleep(1) from numbers(2) settings max_block_size=1, interactive_delay=1000")

        # external tables
        sendEmptyBlock(s)

        summary_progress = Progress()
        non_empty_progress_packets = 0
        # readProgress returns None once the end-of-stream packet arrives.
        while (progress := readProgress(s)) is not None:
            summary_progress += progress
            if progress:
                non_empty_progress_packets += 1

        print(summary_progress)
        # Only non-empty progress packets are counted; eventually there
        # should be at least 3 of them:
        # - 2 for each INSERT block (one of them can be merged with the read
        #   block, hence 3 or 4)
        # - 1 or 2 for each SELECT block
        assert non_empty_progress_packets in (3, 4), f"{non_empty_progress_packets=:}"
        # The socket is closed when the with-block exits.
# Run the check only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()