diff --git a/src/Common/OvercommitTracker.cpp b/src/Common/OvercommitTracker.cpp
index 94b42e77e4e..d24cc24a5f8 100644
--- a/src/Common/OvercommitTracker.cpp
+++ b/src/Common/OvercommitTracker.cpp
@@ -55,20 +55,19 @@ bool OvercommitTracker::needToStopQuery(MemoryTracker * tracker, Int64 amount)
         cancellation_state = QueryCancellationState::RUNNING;
         return true;
     }
+
     required_memory += amount;
-    bool timeout = !cv.wait_for(lk, max_wait_time, [this]()
+    // Record what this thread is waiting for; tryContinueQueryExecutionAfterFree() zeroes the
+    // entry once enough memory has been freed, which is what the wait predicate checks.
+    required_per_thread[tracker] = amount;
+    bool timeout = !cv.wait_for(lk, max_wait_time, [this, tracker]()
     {
-        return freed_momory >= required_memory || cancellation_state == QueryCancellationState::NONE;
+        return required_per_thread[tracker] == 0 || cancellation_state == QueryCancellationState::NONE;
     });
     LOG_DEBUG(getLogger(), "Memory was{} freed within timeout", (timeout ? " not" : ""));
 
-    // If query cancellation is still running, it's possible that other queries will reach
-    // hard limit and end up on waiting on condition variable.
-    // If so we need to specify that some part of freed memory is acquired at this moment.
-    if (!timeout && cancellation_state != QueryCancellationState::NONE)
-        freed_momory -= amount;
-
     required_memory -= amount;
+    required_per_thread.erase(tracker);
     // All required amount of memory is free now and selected query to stop doesn't know about it.
     // As we don't need to free memory, we can continue execution of the selected query.
     if (required_memory == 0 && cancellation_state == QueryCancellationState::SELECTED)
@@ -83,7 +82,14 @@ void OvercommitTracker::tryContinueQueryExecutionAfterFree(Int64 amount)
     {
         freed_momory += amount;
         if (freed_momory >= required_memory)
+        {
+            // Enough memory was freed for every waiter: mark each pending request as satisfied
+            // and start accounting freed memory from scratch for requests that arrive later.
+            for (auto & required : required_per_thread)
+                required.second = 0;
+            freed_momory = 0;
             cv.notify_all();
+        }
     }
 }
 
diff --git a/src/Common/OvercommitTracker.h b/src/Common/OvercommitTracker.h
index 9b6364d91c1..d68e52d2b1d 100644
--- a/src/Common/OvercommitTracker.h
+++ b/src/Common/OvercommitTracker.h
@@ -34,6 +34,13 @@ struct OvercommitRatio
 
 class MemoryTracker;
 
+enum class QueryCancellationState
+{
+    NONE     = 0,  // Hard limit is not reached, there is no selected query to kill.
+    SELECTED = 1,  // Hard limit is reached, query to stop was chosen but it still is not aware of cancellation.
+    RUNNING  = 2,  // Hard limit is reached, selected query has started the process of cancellation.
+};
+
 // Usually it's hard to set some reasonable hard memory limit
 // (especially, the default value). This class introduces new
 // mechanisim for the limiting of memory usage.
@@ -91,15 +98,12 @@ private:
         freed_momory = 0;
     }
 
-    enum class QueryCancellationState
-    {
-        NONE,      // Hard limit is not reached, there is no selected query to kill.
-        SELECTED,  // Hard limit is reached, query to stop was chosen but it still is not aware of cancellation.
-        RUNNING,   // Hard limit is reached, selected query has started the process of cancellation.
-    };
-
     QueryCancellationState cancellation_state;
 
+    // Memory each waiting thread still needs, keyed by its MemoryTracker; an entry is zeroed by
+    // tryContinueQueryExecutionAfterFree() once enough memory has been freed to wake the waiter.
+    std::unordered_map<MemoryTracker *, Int64> required_per_thread;
+
     // Global mutex which is used in ProcessList to synchronize
     // insertion and deletion of queries.
     // OvercommitTracker::pickQueryToExcludeImpl() implementations
@@ -123,7 +127,7 @@ struct UserOvercommitTracker : OvercommitTracker
     ~UserOvercommitTracker() override = default;
 
 protected:
-    void pickQueryToExcludeImpl() override final;
+    void pickQueryToExcludeImpl() override;
 
     Poco::Logger * getLogger() override final { return logger; }
 private:
@@ -138,7 +142,7 @@ struct GlobalOvercommitTracker : OvercommitTracker
     ~GlobalOvercommitTracker() override = default;
 
 protected:
-    void pickQueryToExcludeImpl() override final;
+    void pickQueryToExcludeImpl() override;
 
     Poco::Logger * getLogger() override final { return logger; }
 private:
diff --git a/src/Common/tests/gtest_overcommit_tracker.cpp b/src/Common/tests/gtest_overcommit_tracker.cpp
new file mode 100644
index 00000000000..4f742d81725
--- /dev/null
+++ b/src/Common/tests/gtest_overcommit_tracker.cpp
@@ -0,0 +1,314 @@
+// NOTE(review): the header names on the #include lines were missing from the reviewed patch text;
+// the list below is reconstructed from the names this file uses -- confirm before merging.
+#include <gtest/gtest.h>
+
+#include <atomic>
+#include <thread>
+#include <utility>
+#include <vector>
+
+#include <Common/MemoryTracker.h>
+#include <Common/OvercommitTracker.h>
+#include <Interpreters/ProcessList.h>
+
+using namespace DB;
+
+// Test double for an overcommit tracker: instead of running the real selection logic,
+// pickQueryToExcludeImpl() reports whatever tracker the test supplied via setCandidate().
+template <typename BaseTracker>
+struct OvercommitTrackerForTest : BaseTracker
+{
+    template <typename... Ts>
+    explicit OvercommitTrackerForTest(Ts && ...args)
+        : BaseTracker(std::forward<Ts>(args)...)
+    {}
+
+    void setCandidate(MemoryTracker * candidate)
+    {
+        tracker = candidate;
+    }
+
+protected:
+    void pickQueryToExcludeImpl() override
+    {
+        BaseTracker::picked_tracker = tracker;
+    }
+
+    // Null until setCandidate() is called, so a forgotten call cannot publish an indeterminate pointer.
+    MemoryTracker * tracker = nullptr;
+};
+
+using UserOvercommitTrackerForTest = OvercommitTrackerForTest<UserOvercommitTracker>;
+using GlobalOvercommitTrackerForTest = OvercommitTrackerForTest<GlobalOvercommitTracker>;
+
+// Value handed to setMaxWaitTime(); NOTE(review): presumably microseconds (3 s) -- confirm the unit.
+static constexpr UInt64 WAIT_TIME = 3'000'000;
+
+// Five threads each request 100 and only 50 is freed, so every waiter must be told to stop.
+template <typename T>
+void free_not_continue_test(T & overcommit_tracker)
+{
+    overcommit_tracker.setMaxWaitTime(WAIT_TIME);
+
+    static constexpr size_t THREADS = 5;
+    std::vector<MemoryTracker> trackers(THREADS);
+    std::atomic<int> need_to_stop = 0;
+    std::vector<std::thread> threads;
+    threads.reserve(THREADS);
+
+    MemoryTracker picked;
+    overcommit_tracker.setCandidate(&picked);
+
+    for (size_t i = 0; i < THREADS; ++i)
+    {
+        threads.push_back(std::thread([&, i](){
+            if (overcommit_tracker.needToStopQuery(&trackers[i], 100))
+                ++need_to_stop;
+        }));
+    }
+
+    std::thread([&](){
+        overcommit_tracker.tryContinueQueryExecutionAfterFree(50);
+    }).join();
+
+    for (auto & thread : threads)
+    {
+        thread.join();
+    }
+
+    ASSERT_EQ(need_to_stop, THREADS);
+}
+
+TEST(OvercommitTracker, UserFreeNotContinue)
+{
+    ProcessList process_list;
+    ProcessListForUser user_process_list(&process_list);
+    UserOvercommitTrackerForTest user_overcommit_tracker(&process_list, &user_process_list);
+    free_not_continue_test(user_overcommit_tracker);
+}
+
+TEST(OvercommitTracker, GlobalFreeNotContinue)
+{
+    ProcessList process_list;
+    GlobalOvercommitTrackerForTest global_overcommit_tracker(&process_list);
+    free_not_continue_test(global_overcommit_tracker);
+}
+
+// 5000 is freed while the waiters need 500 in total, so none of them has to stop.
+template <typename T>
+void free_continue_test(T & overcommit_tracker)
+{
+    overcommit_tracker.setMaxWaitTime(WAIT_TIME);
+
+    static constexpr size_t THREADS = 5;
+    std::vector<MemoryTracker> trackers(THREADS);
+    std::atomic<int> need_to_stop = 0;
+    std::vector<std::thread> threads;
+    threads.reserve(THREADS);
+
+    MemoryTracker picked;
+    overcommit_tracker.setCandidate(&picked);
+
+    for (size_t i = 0; i < THREADS; ++i)
+    {
+        threads.push_back(std::thread([&, i](){
+            if (overcommit_tracker.needToStopQuery(&trackers[i], 100))
+                ++need_to_stop;
+        }));
+    }
+
+    std::thread([&](){
+        overcommit_tracker.tryContinueQueryExecutionAfterFree(5000);
+    }).join();
+
+    for (auto & thread : threads)
+    {
+        thread.join();
+    }
+
+    ASSERT_EQ(need_to_stop, 0);
+}
+
+TEST(OvercommitTracker, UserFreeContinue)
+{
+    ProcessList process_list;
+    ProcessListForUser user_process_list(&process_list);
+    UserOvercommitTrackerForTest user_overcommit_tracker(&process_list, &user_process_list);
+    free_continue_test(user_overcommit_tracker);
+}
+
+TEST(OvercommitTracker, GlobalFreeContinue)
+{
+    ProcessList process_list;
+    GlobalOvercommitTrackerForTest global_overcommit_tracker(&process_list);
+    free_continue_test(global_overcommit_tracker);
+}
+
+// As free_continue_test, but the freeing thread then requests memory itself; nothing is freed
+// afterwards, so that late request is expected to be stopped.
+template <typename T>
+void free_continue_and_alloc_test(T & overcommit_tracker)
+{
+    overcommit_tracker.setMaxWaitTime(WAIT_TIME);
+
+    static constexpr size_t THREADS = 5;
+    std::vector<MemoryTracker> trackers(THREADS);
+    std::atomic<int> need_to_stop = 0;
+    std::vector<std::thread> threads;
+    threads.reserve(THREADS);
+
+    MemoryTracker picked;
+    overcommit_tracker.setCandidate(&picked);
+
+    for (size_t i = 0; i < THREADS; ++i)
+    {
+        threads.push_back(std::thread([&, i](){
+            if (overcommit_tracker.needToStopQuery(&trackers[i], 100))
+                ++need_to_stop;
+        }));
+    }
+
+    bool stopped_next = false;
+    std::thread([&](){
+        MemoryTracker failed;
+        overcommit_tracker.tryContinueQueryExecutionAfterFree(5000);
+        stopped_next = overcommit_tracker.needToStopQuery(&failed, 100);
+    }).join();
+
+    for (auto & thread : threads)
+    {
+        thread.join();
+    }
+
+    ASSERT_EQ(need_to_stop, 0);
+    ASSERT_EQ(stopped_next, true);
+}
+
+TEST(OvercommitTracker, UserFreeContinueAndAlloc)
+{
+    ProcessList process_list;
+    ProcessListForUser user_process_list(&process_list);
+    UserOvercommitTrackerForTest user_overcommit_tracker(&process_list, &user_process_list);
+    free_continue_and_alloc_test(user_overcommit_tracker);
+}
+
+TEST(OvercommitTracker, GlobalFreeContinueAndAlloc)
+{
+    ProcessList process_list;
+    GlobalOvercommitTrackerForTest global_overcommit_tracker(&process_list);
+    free_continue_and_alloc_test(global_overcommit_tracker);
+}
+
+// As free_continue_and_alloc_test, but a further 90 is freed concurrently; that is less than
+// the 100 the late request needs, so it is still expected to be stopped.
+template <typename T>
+void free_continue_and_alloc_2_test(T & overcommit_tracker)
+{
+    overcommit_tracker.setMaxWaitTime(WAIT_TIME);
+
+    static constexpr size_t THREADS = 5;
+    std::vector<MemoryTracker> trackers(THREADS);
+    std::atomic<int> need_to_stop = 0;
+    std::vector<std::thread> threads;
+    threads.reserve(THREADS);
+
+    MemoryTracker picked;
+    overcommit_tracker.setCandidate(&picked);
+
+    for (size_t i = 0; i < THREADS; ++i)
+    {
+        threads.push_back(std::thread([&, i](){
+            if (overcommit_tracker.needToStopQuery(&trackers[i], 100))
+                ++need_to_stop;
+        }));
+    }
+
+    bool stopped_next = false;
+    threads.push_back(std::thread([&](){
+        MemoryTracker failed;
+        overcommit_tracker.tryContinueQueryExecutionAfterFree(5000);
+        stopped_next = overcommit_tracker.needToStopQuery(&failed, 100);
+    }));
+
+    overcommit_tracker.tryContinueQueryExecutionAfterFree(90);
+
+    for (auto & thread : threads)
+    {
+        thread.join();
+    }
+
+    ASSERT_EQ(need_to_stop, 0);
+    ASSERT_EQ(stopped_next, true);
+}
+
+TEST(OvercommitTracker, UserFreeContinueAndAlloc2)
+{
+    ProcessList process_list;
+    ProcessListForUser user_process_list(&process_list);
+    UserOvercommitTrackerForTest user_overcommit_tracker(&process_list, &user_process_list);
+    free_continue_and_alloc_2_test(user_overcommit_tracker);
+}
+
+TEST(OvercommitTracker, GlobalFreeContinueAndAlloc2)
+{
+    ProcessList process_list;
+    GlobalOvercommitTrackerForTest global_overcommit_tracker(&process_list);
+    free_continue_and_alloc_2_test(global_overcommit_tracker);
+}
+
+// As free_continue_and_alloc_2_test, but the concurrent free is exactly the 100 the late
+// request needs, so it is expected to continue.
+template <typename T>
+void free_continue_and_alloc_3_test(T & overcommit_tracker)
+{
+    overcommit_tracker.setMaxWaitTime(WAIT_TIME);
+
+    static constexpr size_t THREADS = 5;
+    std::vector<MemoryTracker> trackers(THREADS);
+    std::atomic<int> need_to_stop = 0;
+    std::vector<std::thread> threads;
+    threads.reserve(THREADS);
+
+    MemoryTracker picked;
+    overcommit_tracker.setCandidate(&picked);
+
+    for (size_t i = 0; i < THREADS; ++i)
+    {
+        threads.push_back(std::thread([&, i](){
+            if (overcommit_tracker.needToStopQuery(&trackers[i], 100))
+                ++need_to_stop;
+        }));
+    }
+
+    bool stopped_next = false;
+    threads.push_back(std::thread([&](){
+        MemoryTracker failed;
+        overcommit_tracker.tryContinueQueryExecutionAfterFree(5000);
+        stopped_next = overcommit_tracker.needToStopQuery(&failed, 100);
+    }));
+
+    overcommit_tracker.tryContinueQueryExecutionAfterFree(100);
+
+    for (auto & thread : threads)
+    {
+        thread.join();
+    }
+
+    ASSERT_EQ(need_to_stop, 0);
+    ASSERT_EQ(stopped_next, false);
+}
+
+TEST(OvercommitTracker, UserFreeContinueAndAlloc3)
+{
+    ProcessList process_list;
+    ProcessListForUser user_process_list(&process_list);
+    UserOvercommitTrackerForTest user_overcommit_tracker(&process_list, &user_process_list);
+    free_continue_and_alloc_3_test(user_overcommit_tracker);
+}
+
+TEST(OvercommitTracker, GlobalFreeContinueAndAlloc3)
+{
+    ProcessList process_list;
+    GlobalOvercommitTrackerForTest global_overcommit_tracker(&process_list);
+    free_continue_and_alloc_3_test(global_overcommit_tracker);
+}