Tests: Eliminating the global tests queue to prevent clickhouse-test from hanging when a server dies

This commit is contained in:
Nikita Fomichev 2024-07-03 13:40:36 +02:00
parent fe43ea27d2
commit b88be7260f

View File

@ -34,10 +34,8 @@ import urllib.parse
# for crc32
import zlib
from argparse import ArgumentParser
from contextlib import closing
from datetime import datetime, timedelta
from errno import ESRCH
from queue import Full
from subprocess import PIPE, Popen
from time import sleep, time
from typing import Dict, List, Optional, Set, Tuple, Union
@ -360,39 +358,6 @@ def clickhouse_execute_json(
return rows
class Terminated(KeyboardInterrupt):
pass
def signal_handler(sig, frame):
raise Terminated(f"Terminated with {sig} signal")
def stop_tests():
global stop_tests_triggered_lock
global stop_tests_triggered
global restarted_tests
with stop_tests_triggered_lock:
print("Stopping tests")
if not stop_tests_triggered.is_set():
stop_tests_triggered.set()
# materialize multiprocessing.Manager().list() object before
# sending SIGTERM since this object is a proxy, that requires
# communicating with manager thread, but after SIGTERM will be
# send, this thread will die, and you will get
# ConnectionRefusedError error for any access to "restarted_tests"
# variable.
restarted_tests = [*restarted_tests]
# send signal to all processes in group to avoid hung check triggering
# (to avoid terminating clickhouse-test itself, the signal should be ignored)
signal.signal(signal.SIGTERM, signal.SIG_IGN)
os.killpg(os.getpgid(os.getpid()), signal.SIGTERM)
signal.signal(signal.SIGTERM, signal.SIG_DFL)
def get_db_engine(args, database_name):
if args.replicated_database:
return f" ON CLUSTER test_cluster_database_replicated \
@ -2061,13 +2026,18 @@ class TestSuite:
stop_time = None
exit_code = None
server_died = None
stop_tests_triggered_lock = None
stop_tests_triggered = None
queue = None
multiprocessing_manager = None
restarted_tests = None
class ServerDied(Exception):
pass
class GlobalTimeout(Exception):
pass
def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite]):
all_tests, num_tests, test_suite = all_tests_with_params
global stop_time
@ -2122,24 +2092,17 @@ def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite]):
print(f"\nRunning {about}{num_tests} {test_suite.suite} tests ({proc_name}).\n")
while True:
if is_concurrent:
case = queue.get(timeout=args.timeout * 1.1)
if not case:
break
else:
if all_tests:
case = all_tests.pop(0)
else:
break
if server_died.is_set():
stop_tests()
break
raise ServerDied("Server died")
if stop_time and time() > stop_time:
print("\nStop tests run because global time limit is exceeded.\n")
stop_tests()
break
raise GlobalTimeout("Stop tests run because global time limit is exceeded")
test_case = TestCase(test_suite, case, args, is_concurrent)
@ -2182,18 +2145,15 @@ def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite]):
failures_chain += 1
if test_result.reason == FailureReason.SERVER_DIED:
server_died.set()
stop_tests()
elif test_result.status == TestStatus.SKIPPED:
skipped_total += 1
except KeyboardInterrupt as e:
print(colored("Break tests execution", args, "red"))
stop_tests()
raise e
if failures_chain >= args.max_failures_chain:
stop_tests()
break
raise ServerDied("Max failures chain")
if failures_total > 0:
print(
@ -2390,7 +2350,7 @@ def extract_key(key: str) -> str:
)[1]
def do_run_tests(jobs, test_suite: TestSuite, parallel):
def do_run_tests(jobs, test_suite: TestSuite):
if jobs > 1 and len(test_suite.parallel_tests) > 0:
print(
"Found",
@ -2399,19 +2359,8 @@ def do_run_tests(jobs, test_suite: TestSuite, parallel):
len(test_suite.sequential_tests),
"sequential tests",
)
run_n, run_total = parallel.split("/")
run_n = float(run_n)
run_total = float(run_total)
tests_n = len(test_suite.parallel_tests)
run_total = min(run_total, tests_n)
jobs = min(jobs, tests_n)
run_total = max(jobs, run_total)
batch_size = max(1, len(test_suite.parallel_tests) // jobs)
parallel_tests_array = []
for _ in range(jobs):
parallel_tests_array.append((None, batch_size, test_suite))
# If we don't do random shuffling then there will be always
# nearly the same groups of test suites running concurrently.
@ -2424,24 +2373,20 @@ def do_run_tests(jobs, test_suite: TestSuite, parallel):
# of failures will be nearly the same for all tests from the group.
random.shuffle(test_suite.parallel_tests)
batch_size = max(1, len(test_suite.parallel_tests) // jobs)
parallel_tests_array = []
for job in range(jobs):
range_ = job * batch_size, job * batch_size + batch_size
batch = test_suite.parallel_tests[range_[0] : range_[1]]
parallel_tests_array.append((batch, batch_size, test_suite))
try:
with closing(multiprocessing.Pool(processes=jobs)) as pool:
pool.map_async(run_tests_array, parallel_tests_array)
for suit in test_suite.parallel_tests:
queue.put(suit, timeout=args.timeout * 1.1)
for _ in range(jobs):
queue.put(None, timeout=args.timeout * 1.1)
queue.close()
except Full:
print(
"Couldn't put test to the queue within timeout. Server probably hung."
)
print_stacktraces()
queue.close()
with multiprocessing.Pool(processes=jobs) as pool:
future = pool.map_async(run_tests_array, parallel_tests_array)
future.wait()
finally:
pool.terminate()
pool.close()
pool.join()
run_tests_array(
@ -2807,7 +2752,7 @@ def main(args):
test_suite.cloud_skip_list = cloud_skip_list
test_suite.private_skip_list = private_skip_list
total_tests_run += do_run_tests(args.jobs, test_suite, args.parallel)
total_tests_run += do_run_tests(args.jobs, test_suite)
if server_died.is_set():
exit_code.value = 1
@ -3268,9 +3213,6 @@ if __name__ == "__main__":
stop_time = None
exit_code = multiprocessing.Value("i", 0)
server_died = multiprocessing.Event()
stop_tests_triggered_lock = multiprocessing.Lock()
stop_tests_triggered = multiprocessing.Event()
queue = multiprocessing.Queue(maxsize=1)
multiprocessing_manager = multiprocessing.Manager()
restarted_tests = multiprocessing_manager.list()
@ -3278,9 +3220,6 @@ if __name__ == "__main__":
# infinite tests processes left
# (new process group is required to avoid killing some parent processes)
os.setpgid(0, 0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGHUP, signal_handler)
try:
args = parse_args()