Fix data race in copyFromIStreamWithProgressCallback

This commit is contained in:
Michael Kolupaev 2023-06-26 21:49:44 +00:00
parent 00fcb8aceb
commit df71dcd94d
6 changed files with 35 additions and 27 deletions

View File

@ -62,6 +62,7 @@ struct HTTPSessionReuseTag
{
};
void markSessionForReuse(Poco::Net::HTTPSession & session);
void markSessionForReuse(HTTPSessionPtr session);
void markSessionForReuse(PooledHTTPSessionPtr session);

View File

@ -221,13 +221,12 @@ bool ReadBufferFromS3::nextImpl()
size_t ReadBufferFromS3::readBigAt(char * to, size_t n, size_t range_begin, const std::function<bool(size_t)> & progress_callback)
{
if (n == 0)
return 0;
size_t initial_n = n;
size_t sleep_time_with_backoff_milliseconds = 100;
for (size_t attempt = 0;; ++attempt)
for (size_t attempt = 0; n > 0; ++attempt)
{
bool last_attempt = attempt + 1 >= request_settings.max_single_read_retries;
size_t bytes_copied = 0;
ProfileEventTimeIncrement<Microseconds> watch(ProfileEvents::ReadBufferFromS3Microseconds);
@ -236,14 +235,12 @@ size_t ReadBufferFromS3::readBigAt(char * to, size_t n, size_t range_begin, cons
auto result = sendRequest(range_begin, range_begin + n - 1);
std::istream & istr = result.GetBody();
size_t bytes = copyFromIStreamWithProgressCallback(istr, to, n, progress_callback);
copyFromIStreamWithProgressCallback(istr, to, n, progress_callback, &bytes_copied);
ProfileEvents::increment(ProfileEvents::ReadBufferFromS3Bytes, bytes);
ProfileEvents::increment(ProfileEvents::ReadBufferFromS3Bytes, bytes_copied);
if (read_settings.remote_throttler)
read_settings.remote_throttler->add(bytes, ProfileEvents::RemoteReadThrottlerBytes, ProfileEvents::RemoteReadThrottlerSleepMicroseconds);
return bytes;
read_settings.remote_throttler->add(bytes_copied, ProfileEvents::RemoteReadThrottlerBytes, ProfileEvents::RemoteReadThrottlerSleepMicroseconds);
}
catch (Poco::Exception & e)
{
@ -253,7 +250,13 @@ size_t ReadBufferFromS3::readBigAt(char * to, size_t n, size_t range_begin, cons
sleepForMilliseconds(sleep_time_with_backoff_milliseconds);
sleep_time_with_backoff_milliseconds *= 2;
}
range_begin += bytes_copied;
to += bytes_copied;
n -= bytes_copied;
}
return initial_n;
}
bool ReadBufferFromS3::processException(Poco::Exception & e, size_t read_offset, size_t attempt) const

View File

@ -594,16 +594,14 @@ size_t ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::readBigAt(char * to, si
/// This ensures we've sent at least one HTTP request and populated saved_uri_redirect.
chassert(file_info && file_info->seekable);
if (n == 0)
return 0;
Poco::URI uri_ = saved_uri_redirect.value_or(uri);
if (uri_.getPath().empty())
uri_.setPath("/");
size_t initial_n = n;
size_t milliseconds_to_wait = settings.http_retry_initial_backoff_ms;
for (size_t attempt = 0;; ++attempt)
for (size_t attempt = 0; n > 0; ++attempt)
{
bool last_attempt = attempt + 1 >= settings.http_max_tries;
@ -616,6 +614,7 @@ size_t ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::readBigAt(char * to, si
Poco::Net::HTTPResponse response;
std::istream * result_istr;
size_t bytes_copied = 0;
try
{
@ -629,17 +628,14 @@ size_t ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::readBigAt(char * to, si
"Expected 206 Partial Content, got {} when reading {} range [{}, {})",
toString(response.getStatus()), uri_.toString(), offset, offset + n);
bool cancelled;
size_t r = copyFromIStreamWithProgressCallback(*result_istr, to, n, progress_callback, &cancelled);
if (!cancelled)
copyFromIStreamWithProgressCallback(*result_istr, to, n, progress_callback, &bytes_copied);
if (bytes_copied == n)
{
result_istr->ignore(UINT64_MAX);
/// Response was fully read.
markSessionForReuse(sess);
markSessionForReuse(*sess);
ProfileEvents::increment(ProfileEvents::ReadWriteBufferFromHTTPPreservedSessions);
}
return r;
}
catch (const Poco::Exception & e)
{
@ -664,9 +660,15 @@ size_t ReadWriteBufferFromHTTPBase<UpdatableSessionPtr>::readBigAt(char * to, si
sleepForMilliseconds(milliseconds_to_wait);
milliseconds_to_wait = std::min(milliseconds_to_wait * 2, settings.http_retry_max_backoff_ms);
continue;
}
/// Make sure retries don't re-read the bytes that we've already reported to progress_callback.
offset += bytes_copied;
to += bytes_copied;
n -= bytes_copied;
}
return initial_n;
}
template <typename UpdatableSessionPtr>

View File

@ -64,7 +64,7 @@ std::unique_ptr<SeekableReadBuffer> wrapSeekableReadBufferPointer(SeekableReadBu
return std::make_unique<SeekableReadBufferWrapper<SeekableReadBufferPtr>>(*ptr, SeekableReadBufferPtr{ptr});
}
size_t copyFromIStreamWithProgressCallback(std::istream & istr, char * to, size_t n, const std::function<bool(size_t)> & progress_callback, bool * out_cancelled)
void copyFromIStreamWithProgressCallback(std::istream & istr, char * to, size_t n, const std::function<bool(size_t)> & progress_callback, size_t * out_bytes_copied, bool * out_cancelled)
{
const size_t chunk = DBMS_DEFAULT_BUFFER_SIZE;
if (out_cancelled)
@ -82,6 +82,7 @@ size_t copyFromIStreamWithProgressCallback(std::istream & istr, char * to, size_
bool cancelled = false;
if (gcount && progress_callback)
cancelled = progress_callback(copied);
*out_bytes_copied = copied;
if (gcount != to_copy)
{
@ -103,7 +104,7 @@ size_t copyFromIStreamWithProgressCallback(std::istream & istr, char * to, size_
}
}
return copied;
*out_bytes_copied = copied;
}
}

View File

@ -98,6 +98,7 @@ std::unique_ptr<SeekableReadBuffer> wrapSeekableReadBufferReference(SeekableRead
std::unique_ptr<SeekableReadBuffer> wrapSeekableReadBufferPointer(SeekableReadBufferPtr ptr);
/// Helper for implementing readBigAt().
size_t copyFromIStreamWithProgressCallback(std::istream & istr, char * to, size_t n, const std::function<bool(size_t)> & progress_callback, bool * out_cancelled = nullptr);
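/// Writes the number of bytes copied so far to *out_bytes_copied after each call to the callback, as well as at the end.
/// out_bytes_copied must not be null (it is written unconditionally); out_cancelled is optional and may be null.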
void copyFromIStreamWithProgressCallback(std::istream & istr, char * to, size_t n, const std::function<bool(size_t)> & progress_callback, size_t * out_bytes_copied, bool * out_cancelled = nullptr);
}

View File

@ -1018,11 +1018,11 @@ def test_url_reconnect_in_the_middle(started_cluster):
def select():
global result
result = instance.query(
f"""select sum(cityHash64(x)) from (select toUInt64(id) + sleep(0.1) as x from
f"""select count(), sum(cityHash64(x)) from (select toUInt64(id) + sleep(0.1) as x from
url('http://{started_cluster.minio_host}:{started_cluster.minio_port}/{bucket}/{filename}', 'TSV', '{table_format}')
settings http_max_tries = 10, http_retry_max_backoff_ms=2000, http_send_timeout=1, http_receive_timeout=1)"""
)
assert int(result) == 3914219105369203805
assert result == "1000000\t3914219105369203805\n"
thread = threading.Thread(target=select)
thread.start()
@ -1035,7 +1035,7 @@ def test_url_reconnect_in_the_middle(started_cluster):
thread.join()
assert int(result) == 3914219105369203805
assert result == "1000000\t3914219105369203805\n"
def test_seekable_formats(started_cluster):