From 188e1a87bd672d1f3dd1feec2258b04ca92ea908 Mon Sep 17 00:00:00 2001 From: Dmitry Novik Date: Tue, 12 Apr 2022 12:52:35 +0000 Subject: [PATCH] Check if enough memory is actually freed within timeout --- src/Common/OvercommitTracker.cpp | 31 +++++-- src/Common/OvercommitTracker.h | 5 ++ src/Common/tests/gtest_overcommit_tracker.cpp | 85 +++++++++++++++++++ 3 files changed, 114 insertions(+), 7 deletions(-) diff --git a/src/Common/OvercommitTracker.cpp b/src/Common/OvercommitTracker.cpp index d24cc24a5f8..69f0f2a616c 100644 --- a/src/Common/OvercommitTracker.cpp +++ b/src/Common/OvercommitTracker.cpp @@ -15,6 +15,7 @@ OvercommitTracker::OvercommitTracker(std::mutex & global_mutex_) , global_mutex(global_mutex_) , freed_momory(0) , required_memory(0) + , allow_release(true) {} void OvercommitTracker::setMaxWaitTime(UInt64 wait_time) @@ -56,6 +57,8 @@ bool OvercommitTracker::needToStopQuery(MemoryTracker * tracker, Int64 amount) return true; } + allow_release = true; + required_memory += amount; required_per_thread[tracker] = amount; bool timeout = !cv.wait_for(lk, max_wait_time, [this, tracker]() @@ -65,12 +68,22 @@ bool OvercommitTracker::needToStopQuery(MemoryTracker * tracker, Int64 amount) LOG_DEBUG(getLogger(), "Memory was{} freed within timeout", (timeout ? " not" : "")); required_memory -= amount; + auto still_need = required_per_thread[tracker]; // If enough memory is freed it will be 0 required_per_thread.erase(tracker); + + // If threads where not released since last call of this method, + // we can release them now. + if (allow_release && required_memory <= freed_momory) + { + assert(still_need != 0); + releaseThreads(); + } + // All required amount of memory is free now and selected query to stop doesn't know about it. // As we don't need to free memory, we can continue execution of the selected query. if (required_memory == 0 && cancellation_state == QueryCancellationState::SELECTED) reset(); - return timeout; + return timeout || still_need != 0; } void OvercommitTracker::tryContinueQueryExecutionAfterFree(Int64 amount) @@ -80,12 +93,7 @@ void OvercommitTracker::tryContinueQueryExecutionAfterFree(Int64 amount) { freed_momory += amount; if (freed_momory >= required_memory) - { - for (auto & required : required_per_thread) - required.second = 0; - freed_momory = 0; - cv.notify_all(); - } + releaseThreads(); } } @@ -101,6 +109,15 @@ void OvercommitTracker::onQueryStop(MemoryTracker * tracker) } } +void OvercommitTracker::releaseThreads() +{ + for (auto & required : required_per_thread) + required.second = 0; + freed_momory = 0; + allow_release = false; // To avoid repeating call of this method in OvercommitTracker::needToStopQuery + cv.notify_all(); +} + UserOvercommitTracker::UserOvercommitTracker(DB::ProcessList * process_list, DB::ProcessListForUser * user_process_list_) : OvercommitTracker(process_list->mutex) , user_process_list(user_process_list_) diff --git a/src/Common/OvercommitTracker.h b/src/Common/OvercommitTracker.h index d68e52d2b1d..f954f62a531 100644 --- a/src/Common/OvercommitTracker.h +++ b/src/Common/OvercommitTracker.h @@ -96,8 +96,11 @@ private: picked_tracker = nullptr; cancellation_state = QueryCancellationState::NONE; freed_momory = 0; + allow_release = true; } + void releaseThreads(); + QueryCancellationState cancellation_state; std::unordered_map required_per_thread; @@ -110,6 +113,8 @@ private: std::mutex & global_mutex; Int64 freed_momory; Int64 required_memory; + + bool allow_release; }; namespace DB diff --git a/src/Common/tests/gtest_overcommit_tracker.cpp b/src/Common/tests/gtest_overcommit_tracker.cpp index 4f742d81725..ecf4de54e6a 100644 --- a/src/Common/tests/gtest_overcommit_tracker.cpp +++ b/src/Common/tests/gtest_overcommit_tracker.cpp @@ -295,3 +295,88 @@ TEST(OvercommitTracker, GlobalFreeContinueAndAlloc3) GlobalOvercommitTrackerForTest global_overcommit_tracker(&process_list); free_continue_and_alloc_2_test(global_overcommit_tracker); } + +template +void free_continue_2_test(T & overcommit_tracker) +{ + overcommit_tracker.setMaxWaitTime(WAIT_TIME); + + static constexpr size_t THREADS = 5; + std::vector trackers(THREADS); + std::atomic need_to_stop = 0; + std::vector threads; + threads.reserve(THREADS); + + MemoryTracker picked; + overcommit_tracker.setCandidate(&picked); + + for (size_t i = 0; i < THREADS; ++i) + { + threads.push_back(std::thread([&, i](){ + if (overcommit_tracker.needToStopQuery(&trackers[i], 100)) + ++need_to_stop; + })); + } + + std::thread([&](){ + overcommit_tracker.tryContinueQueryExecutionAfterFree(300); + }).join(); + + for (auto & thread : threads) + { + thread.join(); + } + + ASSERT_EQ(need_to_stop, 2); +} + +TEST(OvercommitTracker, UserFreeContinue2) +{ + ProcessList process_list; + ProcessListForUser user_process_list(&process_list); + UserOvercommitTrackerForTest user_overcommit_tracker(&process_list, &user_process_list); + free_continue_2_test(user_overcommit_tracker); +} + +TEST(OvercommitTracker, GlobalFreeContinue2) +{ + ProcessList process_list; + GlobalOvercommitTrackerForTest global_overcommit_tracker(&process_list); + free_continue_2_test(global_overcommit_tracker); +} + +template +void query_stop_not_continue_test(T & overcommit_tracker) +{ + overcommit_tracker.setMaxWaitTime(WAIT_TIME); + + std::atomic need_to_stop = 0; + + MemoryTracker picked; + overcommit_tracker.setCandidate(&picked); + + MemoryTracker another; + auto thread = std::thread([&](){ + if (overcommit_tracker.needToStopQuery(&another, 100)) + ++need_to_stop; + }); + overcommit_tracker.onQueryStop(&picked); + thread.join(); + + ASSERT_EQ(need_to_stop, 1); +} + +TEST(OvercommitTracker, UserQueryStopNotContinue) +{ + ProcessList process_list; + ProcessListForUser user_process_list(&process_list); + UserOvercommitTrackerForTest user_overcommit_tracker(&process_list, &user_process_list); + query_stop_not_continue_test(user_overcommit_tracker); +} + +TEST(OvercommitTracker, GlobalQueryStopNotContinue) +{ + ProcessList process_list; + GlobalOvercommitTrackerForTest global_overcommit_tracker(&process_list); + query_stop_not_continue_test(global_overcommit_tracker); +}