Tests: Eliminating the global tests queue to prevent clickhouse-test from hanging when a server dies

Nikita Fomichev 2024-07-03 13:40:36 +02:00
parent fe43ea27d2
commit b88be7260f
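
Before this change, concurrent workers pulled test names one at a time from a shared multiprocessing.Queue(maxsize=1); if the server died and the feeder stopped, workers could sit blocked on queue.get() and the whole run appeared hung. The diff below removes that channel entirely: each worker now receives its batch of tests up front, and fatal conditions propagate as ordinary exceptions. A minimal standalone contrast of the two dispatch styles (names and values here are illustrative, not taken from the script):

    import multiprocessing


    def old_style_worker(queue):
        # Old dispatch: block on a shared queue. If the feeder dies before
        # sending the None sentinel, get() blocks and the runner looks hung.
        while True:
            case = queue.get(timeout=11)  # may raise queue.Empty
            if case is None:
                break
            print("old:", case)


    def new_style_worker(batch):
        # New dispatch: the worker owns its whole batch from the start,
        # so there is no shared channel left to block on.
        for case in batch:
            print("new:", case)


    if __name__ == "__main__":
        q = multiprocessing.Queue(maxsize=1)
        p = multiprocessing.Process(target=old_style_worker, args=(q,))
        p.start()
        q.put("00001_select")
        q.put(None)  # sentinel; a missing sentinel is what used to cause hangs
        p.join()

        new_style_worker(["00001_select"])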


@@ -34,10 +34,8 @@ import urllib.parse
 # for crc32
 import zlib
 from argparse import ArgumentParser
-from contextlib import closing
 from datetime import datetime, timedelta
 from errno import ESRCH
-from queue import Full
 from subprocess import PIPE, Popen
 from time import sleep, time
 from typing import Dict, List, Optional, Set, Tuple, Union
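
Both dropped imports served only the queue machinery deleted further down: contextlib.closing wrapped the worker pool in do_run_tests, and queue.Full was the exception caught when feeding the now-removed global queue timed out.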
@@ -360,39 +358,6 @@ def clickhouse_execute_json(
     return rows


-class Terminated(KeyboardInterrupt):
-    pass
-
-
-def signal_handler(sig, frame):
-    raise Terminated(f"Terminated with {sig} signal")
-
-
-def stop_tests():
-    global stop_tests_triggered_lock
-    global stop_tests_triggered
-    global restarted_tests
-
-    with stop_tests_triggered_lock:
-        print("Stopping tests")
-        if not stop_tests_triggered.is_set():
-            stop_tests_triggered.set()
-
-            # materialize the multiprocessing.Manager().list() object before
-            # sending SIGTERM, since this object is a proxy that requires
-            # communicating with the manager thread; after SIGTERM has been
-            # sent, that thread dies, and you will get a
-            # ConnectionRefusedError for any access to the "restarted_tests"
-            # variable.
-            restarted_tests = [*restarted_tests]
-
-            # send the signal to all processes in the group to avoid triggering the hung check
-            # (to avoid terminating clickhouse-test itself, the signal is temporarily ignored)
-            signal.signal(signal.SIGTERM, signal.SIG_IGN)
-            os.killpg(os.getpgid(os.getpid()), signal.SIGTERM)
-            signal.signal(signal.SIGTERM, signal.SIG_DFL)
-
-
 def get_db_engine(args, database_name):
     if args.replicated_database:
         return f" ON CLUSTER test_cluster_database_replicated \
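
The comment in the deleted stop_tests() describes a genuine multiprocessing pitfall that outlives this diff: a Manager().list() is a proxy object that round-trips to the manager process on every access, so its contents must be copied into a plain list while that process is still alive. A standalone sketch with illustrative values:

    import multiprocessing

    if __name__ == "__main__":
        manager = multiprocessing.Manager()
        restarted = manager.list(["00001_select", "00002_insert"])

        # Materialize the proxy while the manager process is alive;
        # this is the [*restarted_tests] trick from the removed code.
        snapshot = [*restarted]

        manager.shutdown()
        # Touching `restarted` now raises a connection error, but the
        # plain-list snapshot remains usable:
        print(snapshot)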
@@ -2061,13 +2026,18 @@ class TestSuite:
 stop_time = None
 exit_code = None
 server_died = None
-stop_tests_triggered_lock = None
-stop_tests_triggered = None
-queue = None
 multiprocessing_manager = None
 restarted_tests = None


+class ServerDied(Exception):
+    pass
+
+
+class GlobalTimeout(Exception):
+    pass
+
+
 def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite]):
     all_tests, num_tests, test_suite = all_tests_with_params
     global stop_time
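
The two new exception classes replace the signal-based stop_tests() coordination: raising inside a worker unwinds only that worker's loop, and the failure is recorded on the AsyncResult that do_run_tests waits on. A minimal sketch of that propagation, with hypothetical task values:

    import multiprocessing


    class ServerDied(Exception):
        pass


    def run_batch(n):
        if n == 2:
            raise ServerDied("Server died")  # aborts only this task
        return n * n


    if __name__ == "__main__":
        with multiprocessing.Pool(processes=2) as pool:
            future = pool.map_async(run_batch, [1, 2, 3])
            future.wait()               # returns once all tasks settle
            print(future.successful())  # False: one task raised
            # future.get() would re-raise ServerDied in the parent here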
@@ -2122,24 +2092,17 @@ def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite]):
     print(f"\nRunning {about}{num_tests} {test_suite.suite} tests ({proc_name}).\n")

     while True:
-        if is_concurrent:
-            case = queue.get(timeout=args.timeout * 1.1)
-            if not case:
-                break
-        else:
-            if all_tests:
-                case = all_tests.pop(0)
-            else:
-                break
+        if all_tests:
+            case = all_tests.pop(0)
+        else:
+            break

         if server_died.is_set():
-            stop_tests()
-            break
+            raise ServerDied("Server died")

         if stop_time and time() > stop_time:
             print("\nStop tests run because global time limit is exceeded.\n")
-            stop_tests()
-            break
+            raise GlobalTimeout("Stop tests run because global time limit is exceeded")

         test_case = TestCase(test_suite, case, args, is_concurrent)
@@ -2182,18 +2145,15 @@ def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite]):
                     failures_chain += 1
                     if test_result.reason == FailureReason.SERVER_DIED:
                         server_died.set()
-                        stop_tests()
                 elif test_result.status == TestStatus.SKIPPED:
                     skipped_total += 1

         except KeyboardInterrupt as e:
             print(colored("Break tests execution", args, "red"))
-            stop_tests()
             raise e

         if failures_chain >= args.max_failures_chain:
-            stop_tests()
-            break
+            raise ServerDied("Max failures chain")

     if failures_total > 0:
         print(
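
With exceptions doing the aborting, the explicit stop_tests() calls are redundant: a SERVER_DIED result still sets the shared server_died event, which every worker checks at the top of its loop and turns into a ServerDied of its own, and an overly long failure chain now raises instead of silently breaking out.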
@@ -2390,7 +2350,7 @@ def extract_key(key: str) -> str:
     )[1]


-def do_run_tests(jobs, test_suite: TestSuite, parallel):
+def do_run_tests(jobs, test_suite: TestSuite):
     if jobs > 1 and len(test_suite.parallel_tests) > 0:
         print(
             "Found",
@@ -2399,19 +2359,8 @@ def do_run_tests(jobs, test_suite: TestSuite):
             len(test_suite.sequential_tests),
             "sequential tests",
         )
-        run_n, run_total = parallel.split("/")
-        run_n = float(run_n)
-        run_total = float(run_total)
         tests_n = len(test_suite.parallel_tests)
-        run_total = min(run_total, tests_n)

         jobs = min(jobs, tests_n)
-        run_total = max(jobs, run_total)
-        batch_size = max(1, len(test_suite.parallel_tests) // jobs)
-        parallel_tests_array = []
-        for _ in range(jobs):
-            parallel_tests_array.append((None, batch_size, test_suite))

         # If we don't do random shuffling then there will be always
         # nearly the same groups of test suites running concurrently.
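
The removed run_n/run_total arithmetic parsed the old parallel argument of the form "N/M"; since batching is now computed per invocation from jobs alone (see the slices added in the next hunk), the only clamping left is jobs = min(jobs, tests_n).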
@@ -2424,24 +2373,20 @@ def do_run_tests(jobs, test_suite: TestSuite):
         # of failures will be nearly the same for all tests from the group.
         random.shuffle(test_suite.parallel_tests)

+        batch_size = max(1, len(test_suite.parallel_tests) // jobs)
+        parallel_tests_array = []
+        for job in range(jobs):
+            range_ = job * batch_size, job * batch_size + batch_size
+            batch = test_suite.parallel_tests[range_[0] : range_[1]]
+            parallel_tests_array.append((batch, batch_size, test_suite))
+
         try:
-            with closing(multiprocessing.Pool(processes=jobs)) as pool:
-                pool.map_async(run_tests_array, parallel_tests_array)
-
-                for suit in test_suite.parallel_tests:
-                    queue.put(suit, timeout=args.timeout * 1.1)
-
-                for _ in range(jobs):
-                    queue.put(None, timeout=args.timeout * 1.1)
-
-                queue.close()
-        except Full:
-            print(
-                "Couldn't put test to the queue within timeout. Server probably hung."
-            )
-            print_stacktraces()
-            queue.close()
-
-        pool.join()
+            with multiprocessing.Pool(processes=jobs) as pool:
+                future = pool.map_async(run_tests_array, parallel_tests_array)
+                future.wait()
+        finally:
+            pool.terminate()
+            pool.close()
+            pool.join()

     run_tests_array(
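
The rewritten body pre-slices the shuffled tests into contiguous per-worker batches, then leans on the Pool context manager plus an unconditional finally for teardown, so no worker can outlive a failed run. A standalone sketch of the same lifecycle with a toy work function (names are illustrative):

    import multiprocessing


    def run_batch(batch):
        return len(batch)  # stand-in for actually running the batch


    if __name__ == "__main__":
        batches = [["00001_select"], ["00002_insert", "00003_alter"]]
        try:
            with multiprocessing.Pool(processes=2) as pool:
                future = pool.map_async(run_batch, batches)
                future.wait()
        finally:
            # The with-block's __exit__ already calls terminate(); repeating
            # terminate/close/join on a terminated pool is harmless, and the
            # finally guarantees teardown even if the batches raised.
            pool.terminate()
            pool.close()
            pool.join()

Note that future.wait() only blocks until the tasks settle; any exception from a batch stays stored on the result object, and the caller relies on the shared server_died flag, checked right after do_run_tests returns, to set the exit code.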
@@ -2807,7 +2752,7 @@ def main(args):
         test_suite.cloud_skip_list = cloud_skip_list
         test_suite.private_skip_list = private_skip_list

-        total_tests_run += do_run_tests(args.jobs, test_suite, args.parallel)
+        total_tests_run += do_run_tests(args.jobs, test_suite)

         if server_died.is_set():
             exit_code.value = 1
@@ -3268,9 +3213,6 @@ if __name__ == "__main__":
     stop_time = None
     exit_code = multiprocessing.Value("i", 0)
     server_died = multiprocessing.Event()
-    stop_tests_triggered_lock = multiprocessing.Lock()
-    stop_tests_triggered = multiprocessing.Event()
-    queue = multiprocessing.Queue(maxsize=1)
     multiprocessing_manager = multiprocessing.Manager()
     restarted_tests = multiprocessing_manager.list()
@@ -3278,9 +3220,6 @@ if __name__ == "__main__":
     # infinite tests processes left
     # (new process group is required to avoid killing some parent processes)
     os.setpgid(0, 0)
-    signal.signal(signal.SIGTERM, signal_handler)
-    signal.signal(signal.SIGINT, signal_handler)
-    signal.signal(signal.SIGHUP, signal_handler)

     try:
         args = parse_args()
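
Dropping the custom handlers restores default signal behaviour: SIGINT becomes a plain KeyboardInterrupt, which run_tests_array already catches and re-raises, and SIGTERM/SIGHUP simply terminate the process. The process group created by os.setpgid(0, 0) is still what lets a supervisor reap the runner and all of its children in one call, roughly:

    import os
    import signal

    os.setpgid(0, 0)  # become leader of a fresh process group (Unix only)
    # Children spawned from here inherit the group, so a supervisor can
    # later reap everything at once with something like:
    #   os.killpg(runner_pgid, signal.SIGKILL)   # runner_pgid is hypothetical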