From 0979155f2f4881ed50427205b30059c1839b2a53 Mon Sep 17 00:00:00 2001 From: Antonio Andelic Date: Wed, 16 Mar 2022 14:59:06 +0000 Subject: [PATCH] Address PR comments --- programs/server/Server.cpp | 5 +- src/IO/IOThreadPool.cpp | 4 +- src/IO/ParallelReadBuffer.cpp | 24 ++-- src/IO/ParallelReadBuffer.h | 23 ++-- src/IO/ReadWriteBufferFromHTTP.h | 124 ++++++++++-------- src/Storages/StorageURL.cpp | 37 ++++-- .../queries/0_stateless/02126_url_auth.python | 13 +- 7 files changed, 126 insertions(+), 104 deletions(-) diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index b565e6bed42..d372ff8ea65 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -555,7 +555,10 @@ if (ThreadFuzzer::instance().isEffective()) config().getUInt("thread_pool_queue_size", 10000) ); - IOThreadPool::initialize(100, 0, 10000); + IOThreadPool::initialize( + config().getUInt("max_io_thread_pool_size", 100), + config().getUInt("max_io_thread_pool_free_size", 0), + config().getUInt("io_thread_pool_queue_size", 10000)); /// Initialize global local cache for remote filesystem. if (config().has("local_cache_for_remote_fs")) diff --git a/src/IO/IOThreadPool.cpp b/src/IO/IOThreadPool.cpp index e33510d1242..4014d00d8b8 100644 --- a/src/IO/IOThreadPool.cpp +++ b/src/IO/IOThreadPool.cpp @@ -15,7 +15,7 @@ void IOThreadPool::initialize(size_t max_threads, size_t max_free_threads, size_ { if (instance) { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The IO thread pool is initialized twice"); + throw Exception(ErrorCodes::LOGICAL_ERROR, "The IO thread pool is initialized twice"); } instance = std::make_unique(max_threads, max_free_threads, queue_size, false /*shutdown_on_exception*/); @@ -25,7 +25,7 @@ ThreadPool & IOThreadPool::get() { if (!instance) { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The IO thread pool is not initialized"); + throw Exception(ErrorCodes::LOGICAL_ERROR, "The IO thread pool is not initialized"); } return *instance; diff --git a/src/IO/ParallelReadBuffer.cpp b/src/IO/ParallelReadBuffer.cpp index 9b8c438fa89..e97a2e07d2e 100644 --- a/src/IO/ParallelReadBuffer.cpp +++ b/src/IO/ParallelReadBuffer.cpp @@ -31,7 +31,7 @@ bool ParallelReadBuffer::addReaderToPool(std::unique_lock & /*buffer return false; } - auto worker = read_workers.emplace_back(std::make_shared(std::move(reader->first), reader->second)); + auto worker = read_workers.emplace_back(std::make_shared(std::move(reader))); ThreadGroupStatusPtr running_group = CurrentThread::isInitialized() && CurrentThread::get().getThreadGroup() ? CurrentThread::get().getThreadGroup() @@ -87,7 +87,7 @@ off_t ParallelReadBuffer::seek(off_t offset, int whence) std::unique_lock lock{mutex}; const auto offset_is_in_range - = [&](const auto & range) { return static_cast(offset) >= range.from && static_cast(offset) < range.to; }; + = [&](const auto & range) { return static_cast(offset) >= range.left && static_cast(offset) <= *range.right; }; while (!read_workers.empty() && (offset < current_position || !offset_is_in_range(read_workers.front()->range))) { @@ -98,7 +98,7 @@ off_t ParallelReadBuffer::seek(off_t offset, int whence) { auto & front_worker = read_workers.front(); auto & segments = front_worker->segments; - current_position = front_worker->range.from; + current_position = front_worker->range.left; while (true) { next_condvar.wait(lock, [&] { return emergency_stop || !segments.empty(); }); @@ -151,20 +151,22 @@ off_t ParallelReadBuffer::getPosition() bool ParallelReadBuffer::currentWorkerReady() const { - return !read_workers.empty() && (read_workers.front()->finished || !read_workers.front()->segments.empty()); + assert(!read_workers.empty()); + return read_workers.front()->finished || !read_workers.front()->segments.empty(); } bool ParallelReadBuffer::currentWorkerCompleted() const { - return !read_workers.empty() && read_workers.front()->finished && read_workers.front()->segments.empty(); + assert(!read_workers.empty()); + return read_workers.front()->finished && read_workers.front()->segments.empty(); } void ParallelReadBuffer::handleEmergencyStop() { + // this can only be called from the main thread when there is an exception + assert(background_exception); if (background_exception) std::rethrow_exception(background_exception); - else - throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Emergency stop"); } bool ParallelReadBuffer::nextImpl() @@ -183,17 +185,17 @@ bool ParallelReadBuffer::nextImpl() return emergency_stop || currentWorkerReady(); }); - if (emergency_stop) - handleEmergencyStop(); - bool worker_removed = false; /// Remove completed units - while (!read_workers.empty() && currentWorkerCompleted()) + while (!read_workers.empty() && currentWorkerCompleted() && !emergency_stop) { read_workers.pop_front(); worker_removed = true; } + if (emergency_stop) + handleEmergencyStop(); + if (worker_removed) addReaders(lock); diff --git a/src/IO/ParallelReadBuffer.h b/src/IO/ParallelReadBuffer.h index ab0aee959eb..5d67b3e5b99 100644 --- a/src/IO/ParallelReadBuffer.h +++ b/src/IO/ParallelReadBuffer.h @@ -67,18 +67,12 @@ private: }; public: - struct Range - { - size_t from; - size_t to; - }; - - using ReaderWithRange = std::pair; + SeekableReadBufferPtr read_buffer; class ReadBufferFactory { public: - virtual std::optional getReader() = 0; + virtual SeekableReadBufferPtr getReader() = 0; virtual ~ReadBufferFactory() = default; virtual off_t seek(off_t off, int whence) = 0; virtual std::optional getTotalSize() = 0; @@ -96,9 +90,10 @@ private: /// Reader in progress with a list of read segments struct ReadWorker { - explicit ReadWorker(ReadBufferPtr reader_, const Range & range_) - : reader(reader_), range(range_), bytes_left(range_.to - range_.from) + explicit ReadWorker(SeekableReadBufferPtr reader_) : reader(std::move(reader_)), range(reader->getRemainingReadRange()) { + assert(range.right); + bytes_left = *range.right - range.left + 1; } Segment nextSegment() @@ -106,14 +101,14 @@ private: assert(!segments.empty()); auto next_segment = std::move(segments.front()); segments.pop_front(); - range.from += next_segment.size(); + range.left += next_segment.size(); return next_segment; } - ReadBufferPtr reader; + SeekableReadBufferPtr reader; std::deque segments; bool finished{false}; - Range range; + SeekableReadBuffer::Range range; size_t bytes_left{0}; }; @@ -124,7 +119,7 @@ private: /// First worker in deque processed and flushed all data bool currentWorkerCompleted() const; - [[noreturn]] void handleEmergencyStop(); + void handleEmergencyStop(); void addReaders(std::unique_lock & buffer_lock); bool addReaderToPool(std::unique_lock & buffer_lock); diff --git a/src/IO/ReadWriteBufferFromHTTP.h b/src/IO/ReadWriteBufferFromHTTP.h index 4d1606aa517..061dd772212 100644 --- a/src/IO/ReadWriteBufferFromHTTP.h +++ b/src/IO/ReadWriteBufferFromHTTP.h @@ -207,19 +207,7 @@ namespace detail { try { - call(response, Poco::Net::HTTPRequest::HTTP_HEAD); - - while (isRedirect(response.getStatus())) - { - Poco::URI uri_redirect(response.get("Location")); - if (remote_host_filter) - remote_host_filter->checkURL(uri_redirect); - - session->updateSession(uri_redirect); - - istr = callImpl(uri_redirect, response, method); - } - + callWithRedirects(response, Poco::Net::HTTPRequest::HTTP_HEAD); break; } catch (const Poco::Exception & e) @@ -324,7 +312,36 @@ namespace detail } } - void call(Poco::Net::HTTPResponse & response, const String & method_) + static bool isRetriableError(const Poco::Net::HTTPResponse::HTTPStatus http_status) noexcept + { + constexpr std::array non_retriable_errors{ + Poco::Net::HTTPResponse::HTTPStatus::HTTP_BAD_REQUEST, + Poco::Net::HTTPResponse::HTTPStatus::HTTP_UNAUTHORIZED, + Poco::Net::HTTPResponse::HTTPStatus::HTTP_NOT_FOUND, + Poco::Net::HTTPResponse::HTTPStatus::HTTP_FORBIDDEN, + Poco::Net::HTTPResponse::HTTPStatus::HTTP_METHOD_NOT_ALLOWED}; + + return std::all_of( + non_retriable_errors.begin(), non_retriable_errors.end(), [&](const auto status) { return http_status != status; }); + } + + void callWithRedirects(Poco::Net::HTTPResponse & response, const String & method_, bool throw_on_all_errors = false) + { + call(response, method_, throw_on_all_errors); + + while (isRedirect(response.getStatus())) + { + Poco::URI uri_redirect(response.get("Location")); + if (remote_host_filter) + remote_host_filter->checkURL(uri_redirect); + + session->updateSession(uri_redirect); + + istr = callImpl(uri_redirect, response, method); + } + } + + void call(Poco::Net::HTTPResponse & response, const String & method_, bool throw_on_all_errors = false) { try { @@ -332,18 +349,18 @@ namespace detail } catch (...) { + if (throw_on_all_errors) + { + throw; + } + auto http_status = response.getStatus(); - if (http_status == Poco::Net::HTTPResponse::HTTPStatus::HTTP_NOT_FOUND - && http_skip_not_found_url) + if (http_status == Poco::Net::HTTPResponse::HTTPStatus::HTTP_NOT_FOUND && http_skip_not_found_url) { initialization_error = InitializeError::SKIP_NOT_FOUND_URL; } - else if (http_status == Poco::Net::HTTPResponse::HTTPStatus::HTTP_BAD_REQUEST - || http_status == Poco::Net::HTTPResponse::HTTPStatus::HTTP_UNAUTHORIZED - || http_status == Poco::Net::HTTPResponse::HTTPStatus::HTTP_NOT_FOUND - || http_status == Poco::Net::HTTPResponse::HTTPStatus::HTTP_FORBIDDEN - || http_status == Poco::Net::HTTPResponse::HTTPStatus::HTTP_METHOD_NOT_ALLOWED) + else if (!isRetriableError(http_status)) { initialization_error = InitializeError::NON_RETRIABLE_ERROR; exception = std::current_exception(); @@ -579,6 +596,8 @@ namespace detail return offset_; } + SeekableReadBuffer::Range getRemainingReadRange() const override { return {getOffset(), read_range.end}; } + std::string getResponseCookie(const std::string & name, const std::string & def) const { for (const auto & cookie : cookies) @@ -620,35 +639,35 @@ class RangeGenerator { public: explicit RangeGenerator(size_t total_size_, size_t range_step_, size_t range_start = 0) - : from_range(range_start), range_step(range_step_), total_size(total_size_) + : from(range_start), range_step(range_step_), total_size(total_size_) { } - size_t totalRanges() const { return static_cast(round(static_cast(total_size - from_range) / range_step)); } + size_t totalRanges() const { return static_cast(round(static_cast(total_size - from) / range_step)); } using Range = std::pair; // return upper exclusive range of values, i.e. [from_range, to_range> std::optional nextRange() { - if (from_range >= total_size) + if (from >= total_size) { return std::nullopt; } - auto to_range = from_range + range_step; - if (to_range >= total_size) + auto to = from + range_step; + if (to >= total_size) { - to_range = total_size; + to = total_size; } - Range range{from_range, to_range}; - from_range = to_range; + Range range{from, to}; + from = to; return std::move(range); } private: - size_t from_range; + size_t from; size_t range_step; size_t total_size; }; @@ -731,34 +750,30 @@ public: { } - using Range = ParallelReadBuffer::Range; - using ReaderWithRange = ParallelReadBuffer::ReaderWithRange; - std::optional getReader() override + SeekableReadBufferPtr getReader() override { const auto next_range = range_generator.nextRange(); if (!next_range) { - return std::nullopt; + return nullptr; } - return std::pair{ - std::make_shared( - uri, - method, - out_stream_callback, - timeouts, - credentials, - max_redirects, - buffer_size, - settings, - http_header_entries, - // HTTP Range has inclusive bounds, i.e. [from, to] - ReadWriteBufferFromHTTP::Range{next_range->first, next_range->second - 1}, - remote_host_filter, - delay_initialization, - use_external_buffer, - skip_not_found_url), - Range{next_range->first, next_range->second}}; + return std::make_shared( + uri, + method, + out_stream_callback, + timeouts, + credentials, + max_redirects, + buffer_size, + settings, + http_header_entries, + // HTTP Range has inclusive bounds, i.e. [from, to] + ReadWriteBufferFromHTTP::Range{next_range->first, next_range->second - 1}, + remote_host_filter, + delay_initialization, + use_external_buffer, + skip_not_found_url); } off_t seek(off_t off, [[maybe_unused]] int whence) override @@ -767,10 +782,7 @@ public: return off; } - std::optional getTotalSize() override - { - return total_object_size; - } + std::optional getTotalSize() override { return total_object_size; } private: RangeGenerator range_generator; diff --git a/src/Storages/StorageURL.cpp b/src/Storages/StorageURL.cpp index ddcb9d34788..28de10669e4 100644 --- a/src/Storages/StorageURL.cpp +++ b/src/Storages/StorageURL.cpp @@ -286,7 +286,23 @@ namespace /* skip_url_not_found_error */ skip_url_not_found_error); Poco::Net::HTTPResponse res; - buffer.call(res, Poco::Net::HTTPRequest::HTTP_HEAD); + + for (size_t i = 0; i < settings.http_max_tries; ++i) + { + try + { + buffer.callWithRedirects(res, Poco::Net::HTTPRequest::HTTP_HEAD, true); + break; + } + catch (...) + { + if (!ReadWriteBufferFromHTTP::isRetriableError(res.getStatus())) + { + throw; + } + } + } + // to check if Range header is supported, we need to send a request with it set const bool supports_ranges = res.has("Accept-Ranges") && res.get("Accept-Ranges") == "bytes"; LOG_TRACE( @@ -294,7 +310,8 @@ namespace fmt::runtime(supports_ranges ? "HTTP Range is supported" : "HTTP Range is not supported")); - if (supports_ranges && res.getStatus() == Poco::Net::HTTPResponse::HTTP_PARTIAL_CONTENT && res.hasContentLength()) + if (supports_ranges && res.getStatus() == Poco::Net::HTTPResponse::HTTP_PARTIAL_CONTENT + && res.hasContentLength()) { LOG_TRACE( &Poco::Logger::get("StorageURLSource"), @@ -324,11 +341,13 @@ namespace chooseCompressionMethod(request_uri.getPath(), compression_method)); } } - catch (...) + catch (const Exception & e) { LOG_TRACE( &Poco::Logger::get(__PRETTY_FUNCTION__), - "Failed to setup ParallelReadBuffer. Falling back to the single-threaded buffer"); + "Failed to setup ParallelReadBuffer because of an exception:\n{}.\nFalling back to the single-threaded " + "buffer", + e.what()); } } @@ -611,15 +630,9 @@ Pipe IStorageURLBase::read( Pipes pipes; pipes.reserve(num_streams); - size_t remaining_download_threads = max_download_threads; - + size_t download_threads = num_streams >= max_download_threads ? 1 : (max_download_threads / num_streams); for (size_t i = 0; i < num_streams; ++i) { - size_t current_need_download_threads = num_streams >= max_download_threads ? 1 : (max_download_threads / num_streams); - size_t current_download_threads = std::min(current_need_download_threads, remaining_download_threads); - remaining_download_threads -= current_download_threads; - current_download_threads = std::max(static_cast(1), current_download_threads); - pipes.emplace_back(std::make_shared( uri_info, getReadMethod(), @@ -633,7 +646,7 @@ Pipe IStorageURLBase::read( max_block_size, ConnectionTimeouts::getHTTPTimeouts(local_context), compression_method, - current_download_threads, + download_threads, headers, params, /* glob_url */ true)); diff --git a/tests/queries/0_stateless/02126_url_auth.python b/tests/queries/0_stateless/02126_url_auth.python index b60438de4ed..57b16fb413e 100644 --- a/tests/queries/0_stateless/02126_url_auth.python +++ b/tests/queries/0_stateless/02126_url_auth.python @@ -121,18 +121,14 @@ class CSVHTTPServer(BaseHTTPRequestHandler): class HTTPServerV6(HTTPServer): address_family = socket.AF_INET6 -def start_server(requests_amount): +def start_server(): if IS_IPV6: httpd = HTTPServerV6(HTTP_SERVER_ADDRESS, CSVHTTPServer) else: httpd = HTTPServer(HTTP_SERVER_ADDRESS, CSVHTTPServer) - def real_func(): - for i in range(requests_amount): - httpd.handle_request() - - t = threading.Thread(target=real_func) - return t + t = threading.Thread(target=httpd.serve_forever) + return t, httpd # test section @@ -217,9 +213,10 @@ def main(): query : 'hello, world', } - t = start_server(len(list(select_requests_url_auth.keys())) * 2) + t, httpd = start_server() t.start() test_select(requests=list(select_requests_url_auth.keys()), answers=list(select_requests_url_auth.values()), test_data=test_data) + httpd.shutdown() t.join() print("PASSED")