From b88be7260f6ce1eda9b949e1aa297eea5a2a110f Mon Sep 17 00:00:00 2001 From: Nikita Fomichev Date: Wed, 3 Jul 2024 13:40:36 +0200 Subject: [PATCH] Tests: Eliminating the global tests queue to prevent clickhouse-test from hanging when a server dies --- tests/clickhouse-test | 121 +++++++++++------------------------------- 1 file changed, 30 insertions(+), 91 deletions(-) diff --git a/tests/clickhouse-test b/tests/clickhouse-test index 36870d59c3a..8e2a256fae2 100755 --- a/tests/clickhouse-test +++ b/tests/clickhouse-test @@ -34,10 +34,8 @@ import urllib.parse # for crc32 import zlib from argparse import ArgumentParser -from contextlib import closing from datetime import datetime, timedelta from errno import ESRCH -from queue import Full from subprocess import PIPE, Popen from time import sleep, time from typing import Dict, List, Optional, Set, Tuple, Union @@ -360,39 +358,6 @@ def clickhouse_execute_json( return rows -class Terminated(KeyboardInterrupt): - pass - - -def signal_handler(sig, frame): - raise Terminated(f"Terminated with {sig} signal") - - -def stop_tests(): - global stop_tests_triggered_lock - global stop_tests_triggered - global restarted_tests - - with stop_tests_triggered_lock: - print("Stopping tests") - if not stop_tests_triggered.is_set(): - stop_tests_triggered.set() - - # materialize multiprocessing.Manager().list() object before - # sending SIGTERM since this object is a proxy, that requires - # communicating with manager thread, but after SIGTERM will be - # send, this thread will die, and you will get - # ConnectionRefusedError error for any access to "restarted_tests" - # variable. - restarted_tests = [*restarted_tests] - - # send signal to all processes in group to avoid hung check triggering - # (to avoid terminating clickhouse-test itself, the signal should be ignored) - signal.signal(signal.SIGTERM, signal.SIG_IGN) - os.killpg(os.getpgid(os.getpid()), signal.SIGTERM) - signal.signal(signal.SIGTERM, signal.SIG_DFL) - - def get_db_engine(args, database_name): if args.replicated_database: return f" ON CLUSTER test_cluster_database_replicated \ @@ -2061,13 +2026,18 @@ class TestSuite: stop_time = None exit_code = None server_died = None -stop_tests_triggered_lock = None -stop_tests_triggered = None -queue = None multiprocessing_manager = None restarted_tests = None +class ServerDied(Exception): + pass + + +class GlobalTimeout(Exception): + pass + + def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite]): all_tests, num_tests, test_suite = all_tests_with_params global stop_time @@ -2122,24 +2092,17 @@ def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite]): print(f"\nRunning {about}{num_tests} {test_suite.suite} tests ({proc_name}).\n") while True: - if is_concurrent: - case = queue.get(timeout=args.timeout * 1.1) - if not case: - break + if all_tests: + case = all_tests.pop(0) else: - if all_tests: - case = all_tests.pop(0) - else: - break + break if server_died.is_set(): - stop_tests() - break + raise ServerDied("Server died") if stop_time and time() > stop_time: print("\nStop tests run because global time limit is exceeded.\n") - stop_tests() - break + raise GlobalTimeout("Stop tests run because global time limit is exceeded") test_case = TestCase(test_suite, case, args, is_concurrent) @@ -2182,18 +2145,15 @@ def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite]): failures_chain += 1 if test_result.reason == FailureReason.SERVER_DIED: server_died.set() - stop_tests() elif test_result.status == TestStatus.SKIPPED: skipped_total += 1 except KeyboardInterrupt as e: print(colored("Break tests execution", args, "red")) - stop_tests() raise e if failures_chain >= args.max_failures_chain: - stop_tests() - break + raise ServerDied("Max failures chain") if failures_total > 0: print( @@ -2390,7 +2350,7 @@ def extract_key(key: str) -> str: )[1] -def do_run_tests(jobs, test_suite: TestSuite, parallel): +def do_run_tests(jobs, test_suite: TestSuite): if jobs > 1 and len(test_suite.parallel_tests) > 0: print( "Found", @@ -2399,19 +2359,8 @@ def do_run_tests(jobs, test_suite: TestSuite, parallel): len(test_suite.sequential_tests), "sequential tests", ) - run_n, run_total = parallel.split("/") - run_n = float(run_n) - run_total = float(run_total) tests_n = len(test_suite.parallel_tests) - run_total = min(run_total, tests_n) - jobs = min(jobs, tests_n) - run_total = max(jobs, run_total) - - batch_size = max(1, len(test_suite.parallel_tests) // jobs) - parallel_tests_array = [] - for _ in range(jobs): - parallel_tests_array.append((None, batch_size, test_suite)) # If we don't do random shuffling then there will be always # nearly the same groups of test suites running concurrently. @@ -2424,25 +2373,21 @@ def do_run_tests(jobs, test_suite: TestSuite, parallel): # of failures will be nearly the same for all tests from the group. random.shuffle(test_suite.parallel_tests) + batch_size = max(1, len(test_suite.parallel_tests) // jobs) + parallel_tests_array = [] + for job in range(jobs): + range_ = job * batch_size, job * batch_size + batch_size + batch = test_suite.parallel_tests[range_[0] : range_[1]] + parallel_tests_array.append((batch, batch_size, test_suite)) + try: - with closing(multiprocessing.Pool(processes=jobs)) as pool: - pool.map_async(run_tests_array, parallel_tests_array) - - for suit in test_suite.parallel_tests: - queue.put(suit, timeout=args.timeout * 1.1) - - for _ in range(jobs): - queue.put(None, timeout=args.timeout * 1.1) - - queue.close() - except Full: - print( - "Couldn't put test to the queue within timeout. Server probably hung." - ) - print_stacktraces() - queue.close() - - pool.join() + with multiprocessing.Pool(processes=jobs) as pool: + future = pool.map_async(run_tests_array, parallel_tests_array) + future.wait() + finally: + pool.terminate() + pool.close() + pool.join() run_tests_array( (test_suite.sequential_tests, len(test_suite.sequential_tests), test_suite) @@ -2807,7 +2752,7 @@ def main(args): test_suite.cloud_skip_list = cloud_skip_list test_suite.private_skip_list = private_skip_list - total_tests_run += do_run_tests(args.jobs, test_suite, args.parallel) + total_tests_run += do_run_tests(args.jobs, test_suite) if server_died.is_set(): exit_code.value = 1 @@ -3268,9 +3213,6 @@ if __name__ == "__main__": stop_time = None exit_code = multiprocessing.Value("i", 0) server_died = multiprocessing.Event() - stop_tests_triggered_lock = multiprocessing.Lock() - stop_tests_triggered = multiprocessing.Event() - queue = multiprocessing.Queue(maxsize=1) multiprocessing_manager = multiprocessing.Manager() restarted_tests = multiprocessing_manager.list() @@ -3278,9 +3220,6 @@ if __name__ == "__main__": # infinite tests processes left # (new process group is required to avoid killing some parent processes) os.setpgid(0, 0) - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGHUP, signal_handler) try: args = parse_args()