add test with broken pipe

This commit is contained in:
Sema Checherinda 2023-07-27 23:44:32 +04:00
parent 7d430b8037
commit 4629ab1df1
2 changed files with 126 additions and 31 deletions

View File

@ -37,9 +37,7 @@ class MockControl:
) )
assert response == "OK", response assert response == "OK", response
def setup_action( def setup_action(self, when, count=None, after=None, action=None, action_args=None):
self, when, count=None, after=None, action="error_500", action_args=None
):
url = f"http://localhost:{self._port}/mock_settings/{when}?nothing=1" url = f"http://localhost:{self._port}/mock_settings/{when}?nothing=1"
if count is not None: if count is not None:
@ -128,8 +126,14 @@ class MockControl:
class _ServerRuntime: class _ServerRuntime:
class SlowPut: class SlowPut:
def __init__( def __init__(
self, probability_=None, timeout_=None, minimal_length_=None, count_=None self,
lock,
probability_=None,
timeout_=None,
minimal_length_=None,
count_=None,
): ):
self.lock = lock
self.probability = probability_ if probability_ is not None else 1 self.probability = probability_ if probability_ is not None else 1
self.timeout = timeout_ if timeout_ is not None else 0.1 self.timeout = timeout_ if timeout_ is not None else 0.1
self.minimal_length = minimal_length_ if minimal_length_ is not None else 0 self.minimal_length = minimal_length_ if minimal_length_ is not None else 0
@ -144,14 +148,15 @@ class _ServerRuntime:
) )
def get_timeout(self, content_length): def get_timeout(self, content_length):
if content_length > self.minimal_length: with self.lock:
if self.count > 0: if content_length > self.minimal_length:
if ( if self.count > 0:
_runtime.slow_put.probability == 1 if (
or random.random() <= _runtime.slow_put.probability _runtime.slow_put.probability == 1
): or random.random() <= _runtime.slow_put.probability
self.count -= 1 ):
return _runtime.slow_put.timeout self.count -= 1
return _runtime.slow_put.timeout
return None return None
class Expected500ErrorAction: class Expected500ErrorAction:
@ -199,29 +204,48 @@ class _ServerRuntime:
) )
request_handler.connection.close() request_handler.connection.close()
class BrokenPipeAction:
def inject_error(self, request_handler):
# partial read
self.rfile.read(50)
time.sleep(1)
request_handler.connection.setsockopt(
socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
)
request_handler.connection.close()
class ConnectionRefusedAction(RedirectAction): class ConnectionRefusedAction(RedirectAction):
pass pass
class CountAfter: class CountAfter:
def __init__(self, count_=None, after_=None, action_=None, action_args_=[]): def __init__(
self, lock, count_=None, after_=None, action_=None, action_args_=[]
):
self.lock = lock
self.count = count_ if count_ is not None else INF_COUNT self.count = count_ if count_ is not None else INF_COUNT
self.after = after_ if after_ is not None else 0 self.after = after_ if after_ is not None else 0
self.action = action_ self.action = action_
self.action_args = action_args_ self.action_args = action_args_
if self.action == "connection_refused": if self.action == "connection_refused":
self.error_handler = _ServerRuntime.ConnectionRefusedAction() self.error_handler = _ServerRuntime.ConnectionRefusedAction()
elif self.action == "connection_reset_by_peer": elif self.action == "connection_reset_by_peer":
self.error_handler = _ServerRuntime.ConnectionResetByPeerAction( self.error_handler = _ServerRuntime.ConnectionResetByPeerAction(
*self.action_args *self.action_args
) )
elif self.action == "broken_pipe":
self.error_handler = _ServerRuntime.BrokenPipeAction()
elif self.action == "redirect_to": elif self.action == "redirect_to":
self.error_handler = _ServerRuntime.RedirectAction(*self.action_args) self.error_handler = _ServerRuntime.RedirectAction(*self.action_args)
else: else:
self.error_handler = _ServerRuntime.Expected500ErrorAction() self.error_handler = _ServerRuntime.Expected500ErrorAction()
@staticmethod @staticmethod
def from_cgi_params(params): def from_cgi_params(lock, params):
return _ServerRuntime.CountAfter( return _ServerRuntime.CountAfter(
lock=lock,
count_=_and_then(params.get("count", [None])[0], int), count_=_and_then(params.get("count", [None])[0], int),
after_=_and_then(params.get("after", [None])[0], int), after_=_and_then(params.get("after", [None])[0], int),
action_=params.get("action", [None])[0], action_=params.get("action", [None])[0],
@ -232,13 +256,14 @@ class _ServerRuntime:
return f"count:{self.count} after:{self.after} action:{self.action} action_args:{self.action_args}" return f"count:{self.count} after:{self.after} action:{self.action} action_args:{self.action_args}"
def has_effect(self): def has_effect(self):
if self.after: with self.lock:
self.after -= 1 if self.after:
if self.after == 0: self.after -= 1
if self.count: if self.after == 0:
self.count -= 1 if self.count:
return True self.count -= 1
return False return True
return False
def inject_error(self, request_handler): def inject_error(self, request_handler):
self.error_handler.inject_error(request_handler) self.error_handler.inject_error(request_handler)
@ -397,14 +422,16 @@ class RequestHandler(http.server.BaseHTTPRequestHandler):
if path[1] == "at_part_upload": if path[1] == "at_part_upload":
params = urllib.parse.parse_qs(parts.query, keep_blank_values=False) params = urllib.parse.parse_qs(parts.query, keep_blank_values=False)
_runtime.at_part_upload = _ServerRuntime.CountAfter.from_cgi_params(params) _runtime.at_part_upload = _ServerRuntime.CountAfter.from_cgi_params(
_runtime.lock, params
)
self.log_message("set at_part_upload %s", _runtime.at_part_upload) self.log_message("set at_part_upload %s", _runtime.at_part_upload)
return self._ok() return self._ok()
if path[1] == "at_object_upload": if path[1] == "at_object_upload":
params = urllib.parse.parse_qs(parts.query, keep_blank_values=False) params = urllib.parse.parse_qs(parts.query, keep_blank_values=False)
_runtime.at_object_upload = _ServerRuntime.CountAfter.from_cgi_params( _runtime.at_object_upload = _ServerRuntime.CountAfter.from_cgi_params(
params _runtime.lock, params
) )
self.log_message("set at_object_upload %s", _runtime.at_object_upload) self.log_message("set at_object_upload %s", _runtime.at_object_upload)
return self._ok() return self._ok()
@ -420,6 +447,7 @@ class RequestHandler(http.server.BaseHTTPRequestHandler):
if path[1] == "slow_put": if path[1] == "slow_put":
params = urllib.parse.parse_qs(parts.query, keep_blank_values=False) params = urllib.parse.parse_qs(parts.query, keep_blank_values=False)
_runtime.slow_put = _ServerRuntime.SlowPut( _runtime.slow_put = _ServerRuntime.SlowPut(
lock=_runtime.lock,
minimal_length_=_and_then(params.get("minimal_length", [None])[0], int), minimal_length_=_and_then(params.get("minimal_length", [None])[0], int),
probability_=_and_then(params.get("probability", [None])[0], float), probability_=_and_then(params.get("probability", [None])[0], float),
timeout_=_and_then(params.get("timeout", [None])[0], float), timeout_=_and_then(params.get("timeout", [None])[0], float),
@ -436,7 +464,7 @@ class RequestHandler(http.server.BaseHTTPRequestHandler):
if path[1] == "at_create_multi_part_upload": if path[1] == "at_create_multi_part_upload":
params = urllib.parse.parse_qs(parts.query, keep_blank_values=False) params = urllib.parse.parse_qs(parts.query, keep_blank_values=False)
_runtime.at_create_multi_part_upload = ( _runtime.at_create_multi_part_upload = (
_ServerRuntime.CountAfter.from_cgi_params(params) _ServerRuntime.CountAfter.from_cgi_params(_runtime.lock, params)
) )
self.log_message( self.log_message(
"set at_create_multi_part_upload %s", "set at_create_multi_part_upload %s",
@ -477,7 +505,7 @@ class RequestHandler(http.server.BaseHTTPRequestHandler):
if upload_id is not None: if upload_id is not None:
if _runtime.at_part_upload is not None: if _runtime.at_part_upload is not None:
self.log_message( self.log_message(
"put error_at_object_upload %s, %s, %s", "put at_part_upload %s, %s, %s",
_runtime.at_part_upload, _runtime.at_part_upload,
upload_id, upload_id,
parts, parts,
@ -492,7 +520,7 @@ class RequestHandler(http.server.BaseHTTPRequestHandler):
if _runtime.at_object_upload is not None: if _runtime.at_object_upload is not None:
if _runtime.at_object_upload.has_effect(): if _runtime.at_object_upload.has_effect():
self.log_message( self.log_message(
"put error_at_object_upload %s, %s, %s", "put error_at_object_upload %s, %s",
_runtime.at_object_upload, _runtime.at_object_upload,
parts, parts,
) )

View File

@ -41,11 +41,6 @@ def broken_s3(init_broken_s3):
yield init_broken_s3 yield init_broken_s3
@pytest.fixture(scope="module")
def init_connection_reset_by_peer(cluster):
yield start_s3_mock(cluster, "connection_reset_by_peer", "8084")
def test_upload_after_check_works(cluster, broken_s3): def test_upload_after_check_works(cluster, broken_s3):
node = cluster.instances["node"] node = cluster.instances["node"]
@ -397,3 +392,75 @@ def test_when_s3_connection_reset_by_peer_at_create_mpu_retried(
or "DB::Exception: Poco::Exception. Code: 1000, e.code() = 104, Connection reset by peer" or "DB::Exception: Poco::Exception. Code: 1000, e.code() = 104, Connection reset by peer"
in error in error
), error ), error
def test_when_s3_broken_pipe_at_upload_is_retried(cluster, broken_s3):
node = cluster.instances["node"]
broken_s3.setup_fake_multpartuploads()
broken_s3.setup_at_part_upload(
count=3,
after=2,
action="broken_pipe",
)
insert_query_id = f"TEST_WHEN_S3_BROKEN_PIPE_AT_UPLOAD"
node.query(
f"""
INSERT INTO
TABLE FUNCTION s3(
'http://resolver:8083/root/data/test_when_s3_broken_pipe_at_upload_is_retried',
'minio', 'minio123',
'CSV', auto, 'none'
)
SELECT
*
FROM system.numbers
LIMIT 1000000
SETTINGS
s3_max_single_part_upload_size=100,
s3_min_upload_part_size=1000000,
s3_check_objects_after_upload=0
""",
query_id=insert_query_id,
)
count_create_multi_part_uploads, count_upload_parts, count_s3_errors = get_counters(
node, insert_query_id, log_type="QueryFinish"
)
assert count_create_multi_part_uploads == 1
assert count_upload_parts == 7
assert count_s3_errors == 3
broken_s3.setup_at_part_upload(
count=1000,
after=2,
action="broken_pipe",
)
insert_query_id = f"TEST_WHEN_S3_BROKEN_PIPE_AT_UPLOAD_1"
error = node.query_and_get_error(
f"""
INSERT INTO
TABLE FUNCTION s3(
'http://resolver:8083/root/data/test_when_s3_broken_pipe_at_upload_is_retried',
'minio', 'minio123',
'CSV', auto, 'none'
)
SELECT
*
FROM system.numbers
LIMIT 1000000
SETTINGS
s3_max_single_part_upload_size=100,
s3_min_upload_part_size=1000000,
s3_check_objects_after_upload=0
""",
query_id=insert_query_id,
)
assert "Code: 1000" in error, error
assert (
"DB::Exception: Poco::Exception. Code: 1000, e.code() = 32, I/O error: Broken pipe"
in error
), error