clickhouse-test: remove global variables to support 'spawn' fork method in macos

This commit is contained in:
Nikita Fomichev 2024-12-17 14:18:25 +01:00
parent 1304b5bad6
commit c4cbeb0428

View File

@ -35,7 +35,7 @@ import urllib.parse
# for crc32
import zlib
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from ast import literal_eval as make_tuple
from contextlib import redirect_stdout
from datetime import datetime, timedelta
@ -1174,6 +1174,7 @@ class TestCase:
def __init__(self, suite, case: str, args, is_concurrent: bool):
self.case: str = case # case file name
self.args: Namespace = args
self.tags: Set[str] = suite.all_tags[case] if case in suite.all_tags else set()
self.random_settings_limits = (
suite.all_random_settings_limits[case]
@ -1246,7 +1247,7 @@ class TestCase:
# should skip test, should increment skipped_total, skip reason
def should_skip_test(self, suite) -> Optional[FailureReason]:
tags = self.tags
args = self.args
if tags and ("disabled" in tags) and not args.disabled:
return FailureReason.DISABLED
@ -1564,13 +1565,13 @@ class TestCase:
return TestResult(self.name, TestStatus.OK, None, total_time, description)
@staticmethod
def print_test_time(test_time) -> str:
if args.print_time:
def print_test_time(self, test_time) -> str:
if self.args.print_time:
return f" {test_time:.2f} sec."
return ""
def process_result(self, result: TestResult, messages):
args = self.args
description_full = messages[result.status]
description_full += self.print_test_time(result.total_time)
if result.reason is not None:
@ -1628,11 +1629,10 @@ class TestCase:
result.description = description_full
return result
@staticmethod
def send_test_name_failed(suite: str, case: str):
def send_test_name_failed(self, suite: str, case: str):
pid = os.getpid()
clickhouse_execute(
args,
self.args,
f"SELECT 'Running test {suite}/{case} from pid={pid}'",
retry_error_codes=True,
)
@ -2261,13 +2261,6 @@ class TestSuite:
return TestSuite(args, suite_path, suite_tmp_path, suite)
stop_time = None
exit_code = None
server_died = None
multiprocessing_manager = None
restarted_tests = None
class ServerDied(Exception):
pass
@ -2276,17 +2269,19 @@ class GlobalTimeout(Exception):
pass
def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite, bool]):
def run_tests_array(
all_tests_with_params: Tuple[List[str], int, TestSuite, bool, Namespace]
):
(
all_tests,
num_tests,
test_suite,
is_concurrent,
args,
exit_code,
server_died,
restarted_tests,
) = all_tests_with_params
global stop_time
global exit_code
global server_died
global restarted_tests
OP_SQUARE_BRACKET = colored("[", args, attrs=["bold"])
CL_SQUARE_BRACKET = colored("]", args, attrs=["bold"])
@ -2345,7 +2340,7 @@ def run_tests_array(all_tests_with_params: Tuple[List[str], int, TestSuite, bool
stop_tests()
raise ServerDied("Server died")
if stop_time and time() > stop_time:
if args.stop_time and time() > args.stop_time:
print("\nStop tests run because global time limit is exceeded.\n")
stop_tests()
raise GlobalTimeout("Stop tests run because global time limit is exceeded")
@ -2624,11 +2619,13 @@ def extract_key(key: str) -> str:
)[1]
def run_tests_process(*args, **kwargs):
return run_tests_array(*args, **kwargs)
def run_tests_process(*args_, **kwargs):
return run_tests_array(*args_, **kwargs)
def do_run_tests(jobs, test_suite: TestSuite):
def do_run_tests(
jobs, test_suite: TestSuite, args, exit_code, restarted_tests, server_died
):
print(
"Found",
len(test_suite.parallel_tests),
@ -2636,6 +2633,7 @@ def do_run_tests(jobs, test_suite: TestSuite):
len(test_suite.sequential_tests),
"sequential tests",
)
if test_suite.parallel_tests:
tests_n = len(test_suite.parallel_tests)
jobs = min(jobs, tests_n)
@ -2661,7 +2659,18 @@ def do_run_tests(jobs, test_suite: TestSuite):
for _ in range(jobs):
process = multiprocessing.Process(
target=run_tests_process,
args=((parallel_tests, batch_size, test_suite, True),),
args=(
(
parallel_tests,
batch_size,
test_suite,
True,
args,
exit_code,
server_died,
restarted_tests,
),
),
)
processes.append(process)
process.start()
@ -2694,7 +2703,6 @@ def do_run_tests(jobs, test_suite: TestSuite):
for p in processes[:]:
if not p.is_alive():
processes.remove(p)
if test_suite.sequential_tests:
run_tests_array(
(
@ -2702,6 +2710,10 @@ def do_run_tests(jobs, test_suite: TestSuite):
len(test_suite.sequential_tests),
test_suite,
False,
args,
exit_code,
server_died,
restarted_tests,
)
)
@ -2945,11 +2957,10 @@ def try_get_skip_list(base_dir, name):
def main(args):
global server_died
global stop_time
global exit_code
global server_logs_level
global restarted_tests
exit_code = multiprocessing.Value("i", 0)
server_died = multiprocessing.Event()
multiprocessing_manager = multiprocessing.Manager()
restarted_tests = multiprocessing_manager.list()
if not check_server_started(args):
msg = "Server is not responding. Cannot execute 'SELECT 1' query."
@ -2989,8 +3000,9 @@ def main(args):
os.environ.setdefault("CLICKHOUSE_CLIENT_SERVER_LOGS_LEVEL", server_logs_level)
# This code is bad as the time is not monotonic
args.stop_time = None
if args.global_time_limit:
stop_time = time() + args.global_time_limit
args.stop_time = time() + args.global_time_limit
if args.zookeeper is None:
args.zookeeper = True
@ -3062,7 +3074,9 @@ def main(args):
test_suite.cloud_skip_list = cloud_skip_list
test_suite.private_skip_list = private_skip_list
total_tests_run += do_run_tests(args.jobs, test_suite)
total_tests_run += do_run_tests(
args.jobs, test_suite, args, exit_code, restarted_tests, server_died
)
if server_died.is_set():
exit_code.value = 1
@ -3567,12 +3581,8 @@ if __name__ == "__main__":
# infinite tests processes left
# (new process group is required to avoid killing some parent processes)
os.setpgid(0, 0)
stop_time = None
exit_code = multiprocessing.Value("i", 0)
server_died = multiprocessing.Event()
multiprocessing_manager = multiprocessing.Manager()
restarted_tests = multiprocessing_manager.list()
# TODO set 'fork' for all CI configurations and 'spawn' for aarch64
multiprocessing.set_start_method("spawn")
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)