# Source: ClickHouse/tests/integration/helpers/client.py
# Snapshot: 2023-08-23 08:08:01 +02:00
#
# Size at snapshot time: 289 lines, 8.4 KiB
# Language: Python

import functools
import logging
import os
import subprocess as sp
import tempfile
from threading import Timer
# Maximum time, in seconds, that CommandRequest waits for the client process
# to exit before subprocess.TimeoutExpired is raised from Popen.wait().
DEFAULT_QUERY_TIMEOUT = 600
class Client:
    """Thin wrapper that runs queries through the ``clickhouse-client`` binary.

    Every query spawns a new client process (see ``get_query_request``); the
    instance itself only stores the base command line built in ``__init__``.
    """

    def __init__(
        self,
        host,
        port=9000,
        command="/usr/bin/clickhouse-client",
        secure=False,
        config=None,
    ):
        """Build the base command line shared by all requests.

        Args:
            host: Server host passed as ``--host``.
            port: Server port passed as ``--port``.
            command: Path to the client executable. If its basename is
                exactly ``clickhouse`` the ``client`` subcommand is appended.
            secure: When true, ``--secure`` is added.
            config: Optional path passed as ``--config-file``.
        """
        self.host = host
        self.port = port

        self.command = [command]
        if os.path.basename(command) == "clickhouse":
            self.command.append("client")

        if secure:
            self.command.append("--secure")

        if config is not None:
            self.command += ["--config-file", config]

        self.command += ["--host", self.host, "--port", str(self.port), "--stacktrace"]

    def stacktraces_on_timeout_decorator(func):
        """Decorate a query method so that a wait timeout dumps server stack traces.

        When the wrapped call raises ``subprocess.TimeoutExpired``, an extra
        query writes the contents of ``system.stack_trace`` to
        ``stacktraces_on_timeout.txt`` and the original exception is re-raised.
        """

        # functools.wraps keeps the wrapped method's __name__ and __doc__.
        @functools.wraps(func)
        def wrap(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except sp.TimeoutExpired:
                # I failed to make pytest print stacktraces using print(...) or logging.debug(...), so...
                try:
                    self.get_query_request(
                        "INSERT INTO TABLE FUNCTION file('stacktraces_on_timeout.txt', 'TSVRaw', 'tn String, tid UInt64, qid String, st String') "
                        "SELECT thread_name, thread_id, query_id, arrayStringConcat(arrayMap(x -> demangle(addressToSymbol(x)), trace), '\n') AS res FROM system.stack_trace "
                        "SETTINGS allow_introspection_functions=1",
                        timeout=60,
                    ).get_answer_and_error()
                except Exception:
                    # The dump is best-effort diagnostics: a failure here must
                    # not replace the timeout the caller needs to see.
                    logging.exception(
                        "Failed to dump server stack traces after a client timeout"
                    )
                raise

        return wrap

    @stacktraces_on_timeout_decorator
    def query(
        self,
        sql,
        stdin=None,
        timeout=None,
        settings=None,
        user=None,
        password=None,
        database=None,
        host=None,
        ignore_error=False,
        query_id=None,
    ):
        """Run ``sql`` and return the client's stdout.

        All arguments are forwarded to ``get_query_request``; the result is
        whatever ``CommandRequest.get_answer`` returns or raises.
        """
        return self.get_query_request(
            sql,
            stdin=stdin,
            timeout=timeout,
            settings=settings,
            user=user,
            password=password,
            database=database,
            host=host,
            ignore_error=ignore_error,
            query_id=query_id,
        ).get_answer()

    def get_query_request(
        self,
        sql,
        stdin=None,
        timeout=None,
        settings=None,
        user=None,
        password=None,
        database=None,
        host=None,
        ignore_error=False,
        query_id=None,
    ):
        """Start a client process for ``sql`` and return its ``CommandRequest``.

        Args:
            sql: Query text.
            stdin: Data fed to the client's stdin. When ``None``, ``sql``
                itself is sent on stdin and ``--multiquery`` is used;
                otherwise ``sql`` is passed via ``--query``.
            timeout: Seconds after which the client process is killed, or
                ``None`` for no kill timer.
            settings: Mapping of setting name to value; each pair becomes
                ``--<name> <value>``.
            user, password, database, query_id: Passed through as the
                corresponding client options when not ``None``.
            host: Optional per-request ``--host``. It is appended after the
                instance-level ``--host`` (presumably the later option takes
                precedence in the client; verify).
            ignore_error: Forwarded to ``CommandRequest``.
        """
        command = self.command[:]

        if stdin is None:
            command += ["--multiquery"]
            stdin = sql
        else:
            command += ["--query", sql]

        if settings is not None:
            for setting, value in settings.items():
                command += ["--" + setting, str(value)]

        # Optional single-value options, emitted in a fixed order.
        optional_flags = (
            ("--user", user),
            ("--password", password),
            ("--database", database),
            ("--host", host),
            ("--query_id", query_id),
        )
        for flag, value in optional_flags:
            if value is not None:
                command += [flag, value]

        return CommandRequest(command, stdin, timeout, ignore_error)

    @stacktraces_on_timeout_decorator
    def query_and_get_error(
        self,
        sql,
        stdin=None,
        timeout=None,
        settings=None,
        user=None,
        password=None,
        database=None,
        query_id=None,
    ):
        """Run ``sql`` expecting failure and return the client's stderr.

        The result is whatever ``CommandRequest.get_error`` returns or raises.
        """
        return self.get_query_request(
            sql,
            stdin=stdin,
            timeout=timeout,
            settings=settings,
            user=user,
            password=password,
            database=database,
            query_id=query_id,
        ).get_error()

    @stacktraces_on_timeout_decorator
    def query_and_get_answer_with_error(
        self,
        sql,
        stdin=None,
        timeout=None,
        settings=None,
        user=None,
        password=None,
        database=None,
        query_id=None,
    ):
        """Run ``sql`` and return ``(stdout, stderr)`` of the client.

        The result is whatever ``CommandRequest.get_answer_and_error`` returns
        or raises.
        """
        return self.get_query_request(
            sql,
            stdin=stdin,
            timeout=timeout,
            settings=settings,
            user=user,
            password=password,
            database=database,
            query_id=query_id,
        ).get_answer_and_error()
class QueryTimeoutExceedException(Exception):
    """Raised when the client process was killed because its timeout elapsed."""
class QueryRuntimeException(Exception):
    """Raised when the client process finished with an unexpected outcome.

    Attributes are set from the constructor arguments:
        returncode: Exit code of the client process.
        stderr: Text the client wrote to stderr.
    """

    def __init__(self, message, returncode, stderr):
        super().__init__(message)
        self.returncode = returncode
        self.stderr = stderr
class CommandRequest:
    """A client command running in a subprocess, with its captured output.

    The process is started in ``__init__``. Its stdout and stderr are
    redirected to temporary files that the ``get_*`` methods read once the
    process has exited.
    """

    def __init__(self, command, stdin=None, timeout=None, ignore_error=False):
        """Start ``command`` and, if requested, arm a kill timer.

        Args:
            command: Argument list passed to ``subprocess.Popen``.
            stdin: Text fed to the process's stdin; ``None`` means empty input.
            timeout: Seconds after which a still-running process is killed,
                or ``None`` for no kill timer.
            ignore_error: When true, the ``get_*`` methods do not raise for
                a kill-timer timeout, and ``get_answer`` does not raise for
                a non-zero exit code or stderr output.
        """
        self.ignore_error = ignore_error
        self.stdout_file = tempfile.TemporaryFile()
        self.stderr_file = tempfile.TemporaryFile()

        # These sanitizer options are the only environment the child gets;
        # verbosity=0 is meant to keep thread sanitizer from printing debug
        # information to the client's stderr.
        env = {
            "ASAN_OPTIONS": "use_sigaltstack=0",
            "TSAN_OPTIONS": "use_sigaltstack=0 verbosity=0",
        }

        try:
            # Write data to tmp file to avoid PIPEs and execution blocking.
            # The child receives its own descriptor, so our handle can be
            # closed as soon as Popen returns.
            with tempfile.TemporaryFile(mode="w+") as stdin_file:
                if stdin is not None:
                    stdin_file.write(stdin)
                stdin_file.seek(0)

                self.process = sp.Popen(
                    command,
                    stdin=stdin_file,
                    stdout=self.stdout_file,
                    stderr=self.stderr_file,
                    env=env,
                    universal_newlines=True,
                )
        except Exception:
            # Nothing will ever read the output files if the process could
            # not be started, so release them before propagating.
            self.stdout_file.close()
            self.stderr_file.close()
            raise

        self.timer = None
        self.process_finished_before_timeout = True
        if timeout is not None:
            self.timer = Timer(timeout, self._kill_on_timeout)
            self.timer.start()

    def _kill_on_timeout(self):
        """Timer callback: kill the process if it is still running."""
        if self.process.poll() is None:
            self.process_finished_before_timeout = False
            self.process.kill()

    def _wait_and_read(self):
        """Wait for the process to exit and return decoded ``(stdout, stderr)``.

        Raises:
            subprocess.TimeoutExpired: If the process does not exit within
                ``DEFAULT_QUERY_TIMEOUT`` seconds. The kill timer, if any,
                is left armed in that case because the process is still
                running.
        """
        self.process.wait(timeout=DEFAULT_QUERY_TIMEOUT)

        # The process has exited, so the kill timer has nothing left to do.
        # Cancelling it lets its thread finish now instead of lingering
        # until the timeout elapses.
        if self.timer is not None:
            self.timer.cancel()

        self.stdout_file.seek(0)
        self.stderr_file.seek(0)
        stdout = self.stdout_file.read().decode("utf-8", errors="replace")
        stderr = self.stderr_file.read().decode("utf-8", errors="replace")
        return stdout, stderr

    def _raise_if_timed_out(self, stdout, stderr):
        """Raise QueryTimeoutExceedException if the kill timer ended the process.

        Nothing is raised when no timer was set or ``ignore_error`` is true.
        """
        if (
            self.timer is not None
            and not self.process_finished_before_timeout
            and not self.ignore_error
        ):
            logging.debug("Timed out. Last stdout:%s, stderr:%s", stdout, stderr)
            raise QueryTimeoutExceedException("Client timed out!")

    def remove_trash_from_stderr(self, stderr):
        """Return ``stderr`` without lines mentioning known harmless noise.

        Lines containing ``completion_queue`` or ``Kick failed`` are dropped.
        A falsy ``stderr`` is returned unchanged.
        """
        # FIXME https://github.com/ClickHouse/ClickHouse/issues/48181
        if not stderr:
            return stderr
        noise_markers = ("completion_queue", "Kick failed")
        kept = [
            line
            for line in stderr.split("\n")
            if not any(marker in line for marker in noise_markers)
        ]
        return "\n".join(kept)

    def get_answer(self):
        """Wait for the process and return its stdout.

        Raises:
            subprocess.TimeoutExpired: See ``_wait_and_read``.
            QueryTimeoutExceedException: The kill timer ended the process
                (unless ``ignore_error``).
            QueryRuntimeException: The exit code is non-zero or stderr has
                content left after noise filtering (unless ``ignore_error``).
        """
        stdout, stderr = self._wait_and_read()
        self._raise_if_timed_out(stdout, stderr)

        failed = self.process.returncode != 0 or self.remove_trash_from_stderr(stderr)
        if failed and not self.ignore_error:
            raise QueryRuntimeException(
                "Client failed! Return code: {}, stderr: {}".format(
                    self.process.returncode, stderr
                ),
                self.process.returncode,
                stderr,
            )

        return stdout

    def get_error(self):
        """Wait for a process that is expected to fail and return its stderr.

        Raises:
            subprocess.TimeoutExpired: See ``_wait_and_read``.
            QueryTimeoutExceedException: The kill timer ended the process
                (unless ``ignore_error``).
            QueryRuntimeException: The process exited with code 0.
        """
        stdout, stderr = self._wait_and_read()
        self._raise_if_timed_out(stdout, stderr)

        if self.process.returncode == 0:
            raise QueryRuntimeException(
                "Client expected to be failed but succeeded! stdout: {}".format(stdout),
                self.process.returncode,
                stderr,
            )

        return stderr

    def get_answer_and_error(self):
        """Wait for the process and return ``(stdout, stderr)``.

        Raises:
            subprocess.TimeoutExpired: See ``_wait_and_read``.
            QueryTimeoutExceedException: The kill timer ended the process
                (unless ``ignore_error``).
        """
        stdout, stderr = self._wait_and_read()
        self._raise_if_timed_out(stdout, stderr)
        return (stdout, stderr)