diff --git a/src/Common/ErrorCodes.cpp b/src/Common/ErrorCodes.cpp
index 95333eccbcd..dec63d114eb 100644
--- a/src/Common/ErrorCodes.cpp
+++ b/src/Common/ErrorCodes.cpp
@@ -643,6 +643,7 @@
     M(672, INVALID_SCHEDULER_NODE) \
     M(673, RESOURCE_ACCESS_DENIED) \
     M(674, RESOURCE_NOT_FOUND) \
+    M(675, THREAD_WAS_CANCELLED) \
     \
     M(999, KEEPER_EXCEPTION) \
     M(1000, POCO_EXCEPTION) \
diff --git a/src/Common/Threading.cpp b/src/Common/Threading.cpp
new file mode 100644
index 00000000000..4d135ef93a6
--- /dev/null
+++ b/src/Common/Threading.cpp
@@ -0,0 +1,511 @@
+#include <Common/Threading.h>
+#include <Common/Exception.h>
+
+#ifdef OS_LINUX /// Because of 'sigqueue' functions, RT signals and futex.
+
+#include <base/getThreadId.h>
+
+#include <bit>
+
+#include <linux/futex.h>
+#include <sys/types.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int THREAD_WAS_CANCELLED;
+}
+
+namespace
+{
+    inline long futexWait(void * address, UInt32 value)
+    {
+        return syscall(SYS_futex, address, FUTEX_WAIT_PRIVATE, value, nullptr, nullptr, 0);
+    }
+
+    inline long futexWake(void * address, int count)
+    {
+        return syscall(SYS_futex, address, FUTEX_WAKE_PRIVATE, count, nullptr, nullptr, 0);
+    }
+
+    // inline void waitFetch(std::atomic<UInt32> & address, UInt32 & value)
+    // {
+    //     futexWait(&address, value);
+    //     value = address.load();
+    // }
+
+    // inline void wakeOne(std::atomic<UInt32> & address)
+    // {
+    //     futexWake(&address, 1);
+    // }
+
+    // inline void wakeAll(std::atomic<UInt32> & address)
+    // {
+    //     futexWake(&address, INT_MAX);
+    // }
+
+    inline constexpr UInt32 lowerValue(UInt64 value)
+    {
+        return UInt32(value & 0xffffffffull);
+    }
+
+    inline constexpr UInt32 upperValue(UInt64 value)
+    {
+        return UInt32(value >> 32ull);
+    }
+
+    inline UInt32 * lowerAddress(void * address)
+    {
+        return reinterpret_cast<UInt32 *>(address) + (std::endian::native == std::endian::big);
+    }
+
+    inline UInt32 * upperAddress(void * address)
+    {
+        return reinterpret_cast<UInt32 *>(address) + (std::endian::native == std::endian::little);
+    }
+
+    inline void waitLowerFetch(std::atomic<UInt64> & address, UInt64 & value)
+    {
+        futexWait(lowerAddress(&address), lowerValue(value));
+        value = address.load();
+    }
+
+    inline bool cancellableWaitLowerFetch(std::atomic<UInt64> & address, UInt64 & value)
+    {
+        bool res = CancelToken::local().wait(lowerAddress(&address), lowerValue(value));
+        value = address.load();
+        return res;
+    }
+
+    inline void wakeLowerOne(std::atomic<UInt64> & address)
+    {
+        syscall(SYS_futex, lowerAddress(&address), FUTEX_WAKE_PRIVATE, 1, nullptr, nullptr, 0);
+    }
+
+    // inline void wakeLowerAll(std::atomic<UInt64> & address)
+    // {
+    //     syscall(SYS_futex, lowerAddress(&address), FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0);
+    // }
+
+    inline void waitUpperFetch(std::atomic<UInt64> & address, UInt64 & value)
+    {
+        futexWait(upperAddress(&address), upperValue(value));
+        value = address.load();
+    }
+
+    inline bool cancellableWaitUpperFetch(std::atomic<UInt64> & address, UInt64 & value)
+    {
+        bool res = CancelToken::local().wait(upperAddress(&address), upperValue(value));
+        value = address.load();
+        return res;
+    }
+
+    // inline void wakeUpperOne(std::atomic<UInt64> & address)
+    // {
+    //     syscall(SYS_futex, upperAddress(&address), FUTEX_WAKE_PRIVATE, 1, nullptr, nullptr, 0);
+    // }
+
+    inline void wakeUpperAll(std::atomic<UInt64> & address)
+    {
+        syscall(SYS_futex, upperAddress(&address), FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0);
+    }
+}
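+
+// Summary of the cancellation handshake implemented below:
+//  - a thread doing a cancellable wait publishes the futex word it sleeps on in its CancelToken `state`;
+//  - a cancelling thread CASes the `canceled` bit into `state`, sets the `signaled` bit in the published
+//    futex word and wakes all threads waiting on it;
+//  - the woken thread observes `canceled`, clears the `signaled` bit and returns false from `wait()`,
+//    after which the caller is expected to call `raise()`.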
+
+CancelToken::Registry::Registry()
+{
+    // setupCancelSignalHandler();
+}
+
+void CancelToken::Registry::insert(CancelToken * token)
+{
+    std::lock_guard lock(mutex);
+    threads[token->thread_id] = token;
+}
+
+void CancelToken::Registry::remove(CancelToken * token)
+{
+    std::lock_guard lock(mutex);
+    threads.erase(token->thread_id);
+}
+
+void CancelToken::Registry::signal(UInt64 tid)
+{
+    std::lock_guard lock(mutex);
+    if (auto it = threads.find(tid); it != threads.end())
+        it->second->signalImpl();
+}
+
+void CancelToken::Registry::signal(UInt64 tid, int code, const String & message)
+{
+    std::lock_guard lock(mutex);
+    if (auto it = threads.find(tid); it != threads.end())
+        it->second->signalImpl(code, message);
+}
+
+CancelToken::Registry & CancelToken::Registry::instance()
+{
+    static Registry registry;
+    return registry;
+}
+
+CancelToken::CancelToken()
+    : state(disabled)
+    , thread_id(getThreadId())
+{
+    Registry::instance().insert(this);
+}
+
+CancelToken::~CancelToken()
+{
+    Registry::instance().remove(this);
+}
+
+void CancelToken::signal(UInt64 tid)
+{
+    Registry::instance().signal(tid);
+}
+
+void CancelToken::signal(UInt64 tid, int code, const String & message)
+{
+    Registry::instance().signal(tid, code, message);
+}
+
+bool CancelToken::wait(UInt32 * address, UInt32 value)
+{
+    chassert((reinterpret_cast<UInt64>(address) & canceled) == 0); // `address` must be 2-byte aligned
+    if (value & signaled) // Can happen after a spurious wake-up caused by cancellation of another thread
+    {
+        // static std::atomic<int> x{0};
+        // if (x++ > 5)
+        //     sleep(3600);
+        return true; // Spin-wait unless the signal is handled
+    }
+
+    UInt64 s = state.load();
+    while (true)
+    {
+        DBG("s={}", s);
+        if (s & disabled)
+        {
+            // Start a non-cancellable wait on the futex. Spurious wake-up is possible.
+            futexWait(address, value);
+            return true; // Disabled - true is forced
+        }
+        if (s & canceled)
+            return false; // Has already been cancelled
+        if (state.compare_exchange_strong(s, reinterpret_cast<UInt64>(address)))
+            break; // This futex has been "acquired" by this token
+    }
+
+    // Start a cancellable wait. Spurious wake-up is possible.
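+    // At this point `state` holds `address`, so a concurrent signalImpl() can set the `signaled` bit
+    // in `*address` and wake us; the futexWait() below may also return spuriously.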
+ DBG("start cancellable wait address={} value={}", static_cast(address), value); + futexWait(address, value); + + // "Release" futex and check for cancellation + s = state.load(); + while (true) + { + DBG("finish cancellable wait, s={}", s); + chassert((s & disabled) != disabled); // `disable()` must not be called from another thread + if (s & canceled) + { + if (s == canceled) + break; // Signaled; futex "release" has been done by the signaling thread + else + { + s = state.load(); + continue; // To avoid race (may lead to futex destruction) we have to wait for signaling thread to finish + } + } + if (state.compare_exchange_strong(s, 0)) + return true; // There was no cancellation; futex "released" + } + + // Reset signaled bit + reinterpret_cast *>(address)->fetch_and(~signaled); + return false; +} + +void CancelToken::raise() +{ + std::unique_lock lock(signal_mutex); + DBG("raise code={} msg={}", exception_code, exception_message); + if (exception_code != 0) + throw DB::Exception( + std::exchange(exception_code, 0), + std::exchange(exception_message, {})); + else + throw DB::Exception(ErrorCodes::THREAD_WAS_CANCELLED, "Thread was cancelled"); +} + +void CancelToken::notifyOne(UInt32 * address) +{ + futexWake(address, 1); +} + +void CancelToken::notifyAll(UInt32 * address) +{ + futexWake(address, INT_MAX); +} + +void CancelToken::signalImpl() +{ + signalImpl(0, {}); +} + +std::mutex CancelToken::signal_mutex; + +void CancelToken::signalImpl(int code, const String & message) +{ + // Serialize all signaling threads to avoid races due to concurrent signal()/raise() calls + std::unique_lock lock(signal_mutex); + + UInt64 s = state.load(); + while (true) + { + DBG("s={}", s); + if (s & canceled) + return; // Already cancelled - don't signal twice + if (state.compare_exchange_strong(s, s | canceled)) + break; // It is the cancelling thread - should deliver signal if necessary + } + + DBG("cancel tid={} code={} msg={}", thread_id, code, message); + exception_code = code; + exception_message = message; + + if ((s & disabled) == disabled) + return; // Cancellation is disabled - just signal token for later, but don't wake + std::atomic * address = reinterpret_cast *>(s & disabled); + DBG("address={}", static_cast(address)); + if (address == nullptr) + return; // Thread is currently not waiting on futex - wake-up not required + + // Set signaled bit + UInt32 value = address->load(); + while (true) + { + if (value & signaled) // Already signaled, just spin-wait until previous signal is handled by waiter + value = address->load(); + else if (address->compare_exchange_strong(value, value | signaled)) + break; + } + + // Wake all threads waiting on `address`, one of them will be cancelled and others will get spurious wake-ups + // Woken canceled thread will reset signaled bit + DBG("wake"); + futexWake(address, INT_MAX); + + // Signaling thread must remove address from state to notify canceled thread that `futexWake()` is done, thus `wake()` can return. + // Otherwise we may have race condition: signaling thread may try to wake futex that has been already destructed. 
+
+Cancellable::Cancellable()
+{
+    CancelToken::local().reset();
+}
+
+Cancellable::~Cancellable()
+{
+    CancelToken::local().disable();
+}
+
+NotCancellable::NotCancellable()
+{
+    CancelToken::local().disable();
+}
+
+NotCancellable::~NotCancellable()
+{
+    CancelToken::local().enable();
+}
+
+CancellableSharedMutex::CancellableSharedMutex()
+    : state(0)
+    , waiters(0)
+{}
+
+void CancellableSharedMutex::lock()
+{
+    UInt64 value = state.load();
+    while (true)
+    {
+        DBG("#A r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
+        if (value & writers)
+        {
+            waiters++;
+            if (!cancellableWaitUpperFetch(state, value))
+            {
+                waiters--;
+                CancelToken::local().raise();
+            }
+            else
+                waiters--;
+        }
+        else if (state.compare_exchange_strong(value, value | writers))
+            break;
+    }
+
+    value |= writers;
+    while (value & readers)
+    {
+        DBG("#B r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
+        if (!cancellableWaitLowerFetch(state, value))
+        {
+            state.fetch_and(~writers);
+            wakeUpperAll(state);
+            CancelToken::local().raise();
+        }
+    }
+}
+
+bool CancellableSharedMutex::try_lock()
+{
+    UInt64 value = state.load();
+    if ((value & (readers | writers)) == 0 && state.compare_exchange_strong(value, value | writers))
+        return true;
+    return false;
+}
+
+void CancellableSharedMutex::unlock()
+{
+    UInt64 value = state.fetch_and(~writers);
+    DBG("r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
+    if (waiters)
+        wakeUpperAll(state);
+}
+
+void CancellableSharedMutex::lock_shared()
+{
+    UInt64 value = state.load();
+    while (true)
+    {
+        DBG("r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
+        if (value & writers)
+        {
+            waiters++;
+            if (!cancellableWaitUpperFetch(state, value))
+            {
+                waiters--;
+                CancelToken::local().raise();
+            }
+            else
+                waiters--;
+        }
+        else if (state.compare_exchange_strong(value, value + 1)) // overflow is not realistic
+            break;
+    }
+}
+
+bool CancellableSharedMutex::try_lock_shared()
+{
+    UInt64 value = state.load();
+    if (!(value & writers) && state.compare_exchange_strong(value, value + 1)) // overflow is not realistic
+        return true;
+    return false;
+}
+
+void CancellableSharedMutex::unlock_shared()
+{
+    UInt64 value = state.fetch_sub(1) - 1;
+    DBG("r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
+    if ((value & (writers | readers)) == writers) // If writer is waiting and no more readers
+        wakeLowerOne(state); // Wake writer
+}
+
+FastSharedMutex::FastSharedMutex()
+    : state(0)
+    , waiters(0)
+{}
+
+void FastSharedMutex::lock()
+{
+    UInt64 value = state.load();
+    while (true)
+    {
+        if (value & writers)
+        {
+            waiters++;
+            waitUpperFetch(state, value);
+            waiters--;
+        }
+        else if (state.compare_exchange_strong(value, value | writers))
+            break;
+    }
+
+    value |= writers;
+    while (value & readers)
+        waitLowerFetch(state, value);
+}
+
+bool FastSharedMutex::try_lock()
+{
+    UInt64 value = 0;
+    if (state.compare_exchange_strong(value, writers))
+        return true;
+    return false;
+}
+
+void FastSharedMutex::unlock()
+{
+    state.store(0);
+    if (waiters)
+        wakeUpperAll(state);
+}
+
+void FastSharedMutex::lock_shared()
+{
+    UInt64 value = state.load();
+    while (true)
+    {
+        if (value & writers)
+        {
+            waiters++;
+            waitUpperFetch(state, value);
+            waiters--;
+        }
+        else if (state.compare_exchange_strong(value, value + 1))
+            break;
+    }
+}
+
+bool FastSharedMutex::try_lock_shared()
+{
+    UInt64 value = state.load();
+    if (!(value & writers) && state.compare_exchange_strong(value, value + 1))
+        return true;
+    return false;
+}
+
+void FastSharedMutex::unlock_shared()
+{
+    UInt64 value = state.fetch_sub(1) - 1;
+    if (value == writers)
+        wakeLowerOne(state); // Wake writer
+}
+
+}
+
+#else
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int THREAD_WAS_CANCELLED;
+}
+
+void CancelToken::raise()
+{
+    throw DB::Exception(ErrorCodes::THREAD_WAS_CANCELLED, "Thread was cancelled");
+}
+
+}
+
+#endif
diff --git a/src/Common/Threading.h b/src/Common/Threading.h
new file mode 100644
index 00000000000..14743def476
--- /dev/null
+++ b/src/Common/Threading.h
@@ -0,0 +1,282 @@
+#pragma once
+
+#include <Common/Exception.h>
+#include <base/types.h>
+
+#ifdef OS_LINUX /// Because of futex
+
+#include <atomic>
+#include <mutex>
+#include <unordered_map>
+
+
+// TODO(serxa): for debug only, remove it
+#if 0
+#include <iostream>
+#include <base/getThreadId.h>
+#define DBG(...) std::cout << fmt::format("\033[01;3{}m[{}] {} {} {}\033[00m {}:{}\n", 1 + getThreadId() % 8, getThreadId(), reinterpret_cast<void *>(this), fmt::format(__VA_ARGS__), __PRETTY_FUNCTION__, __FILE__, __LINE__)
+#else
+#include <base/defines.h>
+#define DBG(...) UNUSED(__VA_ARGS__)
+#endif
+
+namespace DB
+{
+
+// Scoped object, enabling thread cancellation (cannot be nested)
+struct Cancellable
+{
+    Cancellable();
+    ~Cancellable();
+};
+
+// Scoped object, disabling thread cancellation (cannot be nested; must be inside `Cancellable` region)
+struct NotCancellable
+{
+    NotCancellable();
+    ~NotCancellable();
+};
+
+// Responsible for the synchronization needed to deliver a thread cancellation signal.
+// Basic building block for cancellable synchronization primitives.
+// Allows performing cancellable waits on memory addresses (think futex)
+class CancelToken
+{
+public:
+    CancelToken();
+    CancelToken(const CancelToken &) = delete;
+    CancelToken(CancelToken &&) = delete;
+    CancelToken & operator=(const CancelToken &) = delete;
+    ~CancelToken();
+
+    // Returns the token for the current thread
+    static CancelToken & local()
+    {
+        static thread_local CancelToken token;
+        return token;
+    }
+
+    // Cancellable wait on a memory address (futex word).
+    // The thread does an atomic compare-and-sleep on `*address == value`. Waiting continues until `notify_one()`
+    // or `notify_all()` is called with the same `address`, or the calling thread is cancelled using `signal()`.
+    // Note that spurious wake-ups are also possible due to cancellation of other waiters on the same `address`.
+    // WARNING: `address` must be 2-byte aligned and the highest bit of `value` must be zero.
+    // Return value:
+    //   true - woken by either a notify or a spurious wake-up;
+    //   false - iff a cancellation signal has been received.
+    // Implementation details:
+    // It registers `address` inside the token's `state` to allow other threads to wake this thread and deliver the cancellation signal.
+    // The highest bit of `*address` is used for guaranteed delivery of the signal, but it is guaranteed to be zero on return due to cancellation.
+    // Intended to be called only by the thread associated with this token.
+    bool wait(UInt32 * address, UInt32 value);
+
+    // Throws the `DB::Exception` received from `signal()`. Call it if `wait()` returned false.
+    // Intended to be called only by the thread associated with this token.
+    [[noreturn]] void raise();
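+
+    // Typical cancellable wait loop built on top of `wait()`/`raise()` (illustrative sketch;
+    // `futex_word`, `expected` and `condition` are placeholders):
+    //
+    //     while (!condition)
+    //         if (!CancelToken::local().wait(&futex_word, expected))
+    //             CancelToken::local().raise();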
+
+    // Regular wake by address (futex word). It does not interact with the token in any way. It is here to complement `wait()`.
+    // Can be called from any thread.
+    static void notifyOne(UInt32 * address);
+    static void notifyAll(UInt32 * address);
+
+    // Send a cancel signal to the thread with the specified `tid`.
+    // If the thread is waiting in `wait()`, it will be woken up (unless cancellation is disabled).
+    // Can be called from any thread.
+    static void signal(UInt64 tid);
+    static void signal(UInt64 tid, int code, const String & message);
+
+    // Flag used to deliver cancellation via the memory address a thread waits on.
+    // Note that the most significant bit of addresses to be used with `wait()` is reserved.
+    static constexpr UInt32 signaled = 1u << 31u;
+
+private:
+    friend struct Cancellable;
+    friend struct NotCancellable;
+
+    // Restores the initial state so the token can be reused. See `Cancellable` struct.
+    // Intended to be called only by the thread associated with this token.
+    void reset()
+    {
+        state.store(0);
+    }
+
+    // Enable thread cancellation. See `NotCancellable` struct.
+    // Intended to be called only by the thread associated with this token.
+    void enable()
+    {
+        chassert((state.load() & disabled) == disabled);
+        state.fetch_and(~disabled);
+    }
+
+    // Disable thread cancellation. See `NotCancellable` struct.
+    // Intended to be called only by the thread associated with this token.
+    void disable()
+    {
+        chassert((state.load() & disabled) == 0);
+        state.fetch_or(disabled);
+    }
+
+    // Singleton. Maps thread IDs to tokens.
+    struct Registry;
+    friend struct Registry;
+    struct Registry
+    {
+        Registry();
+
+        std::mutex mutex;
+        std::unordered_map<UInt64, CancelToken *> threads; // By thread ID
+
+        void insert(CancelToken * token);
+        void remove(CancelToken * token);
+        void signal(UInt64 tid);
+        void signal(UInt64 tid, int code, const String & message);
+
+        static Registry & instance();
+    };
+
+    // Cancels this token and wakes the thread if necessary.
+    // Can be called from any thread.
+    void signalImpl();
+    void signalImpl(int code, const String & message);
+
+    // Lower bit: cancel signal received flag
+    static constexpr UInt64 canceled = 1;
+
+    // Upper bits - possible values:
+    // 1) all zeros: token is enabled, i.e. a wait() call can return false, the thread is not waiting on any address;
+    // 2) all ones: token is disabled, i.e. a wait() call cannot be cancelled;
+    // 3) specific `address`: token is enabled and the thread is currently waiting on this `address`.
+    static constexpr UInt64 disabled = ~canceled;
+    static_assert(sizeof(UInt32 *) == sizeof(UInt64)); // State must be able to hold an address
+
+    // All signal handling logic should be globally serialized using this mutex
+    static std::mutex signal_mutex;
+
+    // Cancellation state
+    alignas(64) std::atomic<UInt64> state;
+    [[maybe_unused]] char padding[64 - sizeof(state)];
+
+    // Cancellation exception
+    int exception_code;
+    String exception_message;
+
+    // Token is permanently attached to a single thread. There is a one-to-one mapping between threads and tokens.
+    const UInt64 thread_id;
+};
+
+class CancellableSharedMutex
+{
+public:
+    CancellableSharedMutex();
+    ~CancellableSharedMutex() = default;
+    CancellableSharedMutex(const CancellableSharedMutex &) = delete;
+    CancellableSharedMutex & operator=(const CancellableSharedMutex &) = delete;
+
+    // Exclusive ownership
+    void lock();
+    bool try_lock();
+    void unlock();
+
+    // Shared ownership
+    void lock_shared();
+    bool try_lock_shared();
+    void unlock_shared();
+
+private:
+    // 64-bit state layout:
+    //    1b    -   31b   -    1b    -   31b
+    // signaled - writers - signaled - readers
+    // 63------------------------------------0
+    // Two 32-bit words are used for cancellable waiting, so each has its own separate signaled bit
+    static constexpr UInt64 readers = (1ull << 32ull) - 1ull - CancelToken::signaled;
+    static constexpr UInt64 readers_signaled = CancelToken::signaled;
+    static constexpr UInt64 writers = readers << 32ull;
+    static constexpr UInt64 writers_signaled = readers_signaled << 32ull;
+
+    alignas(64) std::atomic<UInt64> state;
+    std::atomic<UInt32> waiters;
+};
+
+class FastSharedMutex
+{
+public:
+    FastSharedMutex();
+    ~FastSharedMutex() = default;
+    FastSharedMutex(const FastSharedMutex &) = delete;
+    FastSharedMutex & operator=(const FastSharedMutex &) = delete;
+
+    // Exclusive ownership
+    void lock();
+    bool try_lock();
+    void unlock();
+
+    // Shared ownership
+    void lock_shared();
+    bool try_lock_shared();
+    void unlock_shared();
+
+private:
+    static constexpr UInt64 readers = (1ull << 32ull) - 1ull; // Lower 32 bits of state
+    static constexpr UInt64 writers = ~readers; // Upper 32 bits of state
+
+    alignas(64) std::atomic<UInt64> state;
+    std::atomic<UInt32> waiters;
+};
+
+}
+
+#else
+
+#include <shared_mutex>
+
+// WARNING: We support cancellable synchronization primitives only on Linux for now
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int THREAD_WAS_CANCELLED;
+}
+
+struct Cancellable
+{
+    Cancellable() = default;
+    ~Cancellable() = default;
+};
+
+struct NotCancellable
+{
+    NotCancellable() = default;
+    ~NotCancellable() = default;
+};
+
+class CancelToken
+{
+public:
+    CancelToken() = default;
+    CancelToken(const CancelToken &) = delete;
+    CancelToken(CancelToken &&) = delete;
+    CancelToken & operator=(const CancelToken &) = delete;
+    ~CancelToken() = default;
+
+    static CancelToken & local()
+    {
+        static CancelToken token;
+        return token;
+    }
+
+    bool wait(UInt32 *, UInt32) { return true; }
+    [[noreturn]] void raise();
+    static void notifyOne(UInt32 *) {}
+    static void notifyAll(UInt32 *) {}
+    static void signal(UInt64) {}
+    static void signal(UInt64, int, const String &) {}
+};
+
+using CancellableSharedMutex = std::shared_mutex;
+using FastSharedMutex = std::shared_mutex;
+
+}
+
+#endif
diff --git a/src/Common/tests/gtest_threading.cpp b/src/Common/tests/gtest_threading.cpp
new file mode 100644
index 00000000000..d9cb8748eeb
--- /dev/null
+++ b/src/Common/tests/gtest_threading.cpp
@@ -0,0 +1,369 @@
+#include <gtest/gtest.h>
+
+#include <thread>
+#include <barrier>
+#include <atomic>
+#include <mutex>
+#include <shared_mutex>
+
+#include "Common/Exception.h"
+#include <Common/Threading.h>
+#include <Common/Stopwatch.h>
+
+#include <base/getThreadId.h>
+#include <base/demangle.h>
+
+
+namespace DB
+{
+    namespace ErrorCodes
+    {
+        extern const int THREAD_WAS_CANCELLED;
+    }
+}
+
+struct NoCancel {};
+
+// for all PerfTests
+static constexpr int requests = 512 * 1024;
+static constexpr int max_threads = 16;
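+
+// Each test is parameterized by a mutex type T and a Status type: constructing a `Status` object at the
+// top of a thread either enables cancellation for it (DB::Cancellable) or is a no-op (NoCancel), so the
+// same test body runs both with and without the cancellation machinery.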
+
+template <typename T, typename Status>
+void TestSharedMutex()
+{
+    // Test that multiple readers can acquire the lock
+    for (int readers = 1; readers <= 128; readers *= 2)
+    {
+        T sm;
+        std::atomic<int> test(0);
+        std::barrier sync(readers + 1);
+
+        std::vector<std::thread> threads;
+        threads.reserve(readers);
+        auto reader = [&]
+        {
+            [[maybe_unused]] Status status;
+            std::shared_lock lock(sm);
+            test++;
+            sync.arrive_and_wait();
+        };
+
+        for (int i = 0; i < readers; i++)
+            threads.emplace_back(reader);
+
+        { // writer
+            [[maybe_unused]] Status status;
+            sync.arrive_and_wait(); // wait for all readers to acquire the lock to avoid blocking them
+            std::unique_lock lock(sm);
+            test++;
+        }
+
+        for (auto & thread : threads)
+            thread.join();
+
+        ASSERT_EQ(test, readers + 1);
+    }
+
+    // Test that multiple writers cannot acquire the lock simultaneously
+    for (int writers = 1; writers <= 128; writers *= 2)
+    {
+        T sm;
+        int test = 0;
+        std::barrier sync(writers);
+        std::vector<std::thread> threads;
+
+        threads.reserve(writers);
+        auto writer = [&]
+        {
+            [[maybe_unused]] Status status;
+            sync.arrive_and_wait();
+            std::unique_lock lock(sm);
+            test++;
+        };
+
+        for (int i = 0; i < writers; i++)
+            threads.emplace_back(writer);
+
+        for (auto & thread : threads)
+            thread.join();
+
+        ASSERT_EQ(test, writers);
+    }
+}
+
+template <typename T, typename Status>
+void TestSharedMutexCancelReader()
+{
+    constexpr int readers = 8;
+    constexpr int tasks_per_reader = 32;
+
+    T sm;
+    std::atomic<int> successes(0);
+    std::atomic<int> cancels(0);
+    std::barrier sync(readers + 1);
+    std::barrier cancel_sync(readers / 2 + 1);
+    std::vector<std::thread> threads;
+
+    std::mutex m;
+    std::vector<UInt64> tids_to_cancel;
+
+    threads.reserve(readers);
+    auto reader = [&] (int reader_id)
+    {
+        if (reader_id % 2 == 0)
+        {
+            std::unique_lock lock(m);
+            tids_to_cancel.emplace_back(getThreadId());
+        }
+        for (int task = 0; task < tasks_per_reader; task++)
+        {
+            try
+            {
+                [[maybe_unused]] Status status;
+                sync.arrive_and_wait(); // (A) sync with the writer
+                sync.arrive_and_wait(); // (B) wait for the writer to acquire the unique_lock
+                std::shared_lock lock(sm);
+                successes++;
+            }
+            catch (DB::Exception & e)
+            {
+                ASSERT_EQ(e.code(), DB::ErrorCodes::THREAD_WAS_CANCELLED);
+                ASSERT_EQ(e.message(), "test");
+                cancels++;
+                cancel_sync.arrive_and_wait(); // (C) sync with the writer
+            }
+        }
+    };
+
+    for (int reader_id = 0; reader_id < readers; reader_id++)
+        threads.emplace_back(reader, reader_id);
+
+    { // writer
+        [[maybe_unused]] Status status;
+        for (int task = 0; task < tasks_per_reader; task++)
+        {
+            sync.arrive_and_wait(); // (A) wait for the readers to finish the previous task
+            ASSERT_EQ(cancels + successes, task * readers);
+            ASSERT_EQ(cancels, task * readers / 2);
+            ASSERT_EQ(successes, task * readers / 2);
+            std::unique_lock lock(sm);
+            sync.arrive_and_wait(); // (B) sync with the readers
+            //std::unique_lock lock(m); // not needed, already synced using barrier
+            for (UInt64 tid : tids_to_cancel)
+                DB::CancelToken::signal(tid, DB::ErrorCodes::THREAD_WAS_CANCELLED, "test");
+
+            // This sync is crucial. It is needed to hold `lock` long enough.
+            // It guarantees that every cancelled thread will find `sm` blocked by the writer, and thus will begin to wait.
+            // The wait() call is required for cancellation. Otherwise, a fast-path acquire without waiting will not generate the exception.
+            // And this is the desired behaviour.
+            cancel_sync.arrive_and_wait(); // (C) wait for cancellation to finish, before unlock.
+        }
+    }
+
+    for (auto & thread : threads)
+        thread.join();
+
+    ASSERT_EQ(successes, tasks_per_reader * readers / 2);
+    ASSERT_EQ(cancels, tasks_per_reader * readers / 2);
+}
+
+template <typename T, typename Status>
+void TestSharedMutexCancelWriter()
+{
+    constexpr int writers = 8;
+    constexpr int tasks_per_writer = 32;
+
+    T sm;
+    std::atomic<int> successes(0);
+    std::atomic<int> cancels(0);
+    std::barrier sync(writers);
+    std::vector<std::thread> threads;
+
+    std::mutex m;
+    std::vector<UInt64> all_tids;
+
+    threads.reserve(writers);
+    auto writer = [&]
+    {
+        {
+            std::unique_lock lock(m);
+            all_tids.emplace_back(getThreadId());
+        }
+        for (int task = 0; task < tasks_per_writer; task++)
+        {
+            try
+            {
+                [[maybe_unused]] Status status;
+                sync.arrive_and_wait(); // (A) sync all threads before the race to acquire the lock
+                std::unique_lock lock(sm);
+                successes++;
+                // The thread that managed to acquire the lock cancels all other waiting writers
+                //std::unique_lock lock(m); // not needed, already synced using barrier
+                for (UInt64 tid : all_tids)
+                {
+                    if (tid != getThreadId())
+                        DB::CancelToken::signal(tid, DB::ErrorCodes::THREAD_WAS_CANCELLED, "test");
+                }
+
+                // This sync is crucial. It is needed to hold `lock` long enough.
+                // It guarantees that every cancelled thread will find `sm` blocked, and thus will begin to wait.
+                // The wait() call is required for cancellation. Otherwise, a fast-path acquire without waiting will not generate the exception.
+                // And this is the desired behaviour.
+                sync.arrive_and_wait(); // (B) wait for cancellation to finish, before unlock.
+            }
+            catch (DB::Exception & e)
+            {
+                ASSERT_EQ(e.code(), DB::ErrorCodes::THREAD_WAS_CANCELLED);
+                ASSERT_EQ(e.message(), "test");
+                cancels++;
+                sync.arrive_and_wait(); // (B) sync with the race winner
+            }
+        }
+    };
+
+    for (int writer_id = 0; writer_id < writers; writer_id++)
+        threads.emplace_back(writer);
+
+    for (auto & thread : threads)
+        thread.join();
+
+    ASSERT_EQ(successes, tasks_per_writer);
+    ASSERT_EQ(cancels, tasks_per_writer * (writers - 1));
+}
+
+template <typename T, typename Status>
+void PerfTestSharedMutexReadersOnly()
+{
+    std::cout << "*** " << demangle(typeid(T).name()) << "/" << demangle(typeid(Status).name()) << " ***" << std::endl;
+
+    for (int thrs = 1; thrs <= max_threads; thrs *= 2)
+    {
+        T sm;
+        std::vector<std::thread> threads;
+        threads.reserve(thrs);
+        auto reader = [&]
+        {
+            [[maybe_unused]] Status status;
+            for (int request = requests / thrs; request; request--)
+            {
+                std::shared_lock lock(sm);
+            }
+        };
+
+        Stopwatch watch;
+        for (int i = 0; i < thrs; i++)
+            threads.emplace_back(reader);
+
+        for (auto & thread : threads)
+            thread.join();
+
+        double ns = watch.elapsedNanoseconds();
+        std::cout << "thrs = " << thrs << ":\t" << ns / requests << " ns\t" << requests * 1e9 / ns << " rps" << std::endl;
+    }
+}
+
+template <typename T, typename Status>
+void PerfTestSharedMutexWritersOnly()
+{
+    std::cout << "*** " << demangle(typeid(T).name()) << "/" << demangle(typeid(Status).name()) << " ***" << std::endl;
+
+    for (int thrs = 1; thrs <= max_threads; thrs *= 2)
+    {
+        int counter = 0;
+        T sm;
+        std::vector<std::thread> threads;
+        threads.reserve(thrs);
+        auto writer = [&]
+        {
+            [[maybe_unused]] Status status;
+            for (int request = requests / thrs; request; request--)
+            {
+                std::unique_lock lock(sm);
+                ASSERT_TRUE(counter % 2 == 0);
+                counter++;
+                std::atomic_signal_fence(std::memory_order::seq_cst); // force the compiler to generate two separate increment instructions
+                counter++;
+            }
+        };
+
+        Stopwatch watch;
+        for (int i = 0; i < thrs; i++)
+            threads.emplace_back(writer);
+
+        for (auto & thread : threads)
+            thread.join();
+
+        ASSERT_EQ(counter, requests * 2);
+
+        double ns = watch.elapsedNanoseconds();
+        std::cout << "thrs = " << thrs << ":\t" << ns / requests << " ns\t" << requests * 1e9 / ns << " rps" << std::endl;
+    }
+}
+
+template <typename T, typename Status>
+void PerfTestSharedMutexRW()
+{
+    std::cout << "*** " << demangle(typeid(T).name()) << "/" << demangle(typeid(Status).name()) << " ***" << std::endl;
+
+    for (int thrs = 1; thrs <= max_threads; thrs *= 2)
+    {
+        int counter = 0;
+        T sm;
+        std::vector<std::thread> threads;
+        threads.reserve(thrs);
+        auto reader = [&]
+        {
+            [[maybe_unused]] Status status;
+            for (int request = requests / thrs / 2; request; request--)
+            {
+                {
+                    std::shared_lock lock(sm);
+                    ASSERT_TRUE(counter % 2 == 0);
+                }
+                {
+                    std::unique_lock lock(sm);
+                    ASSERT_TRUE(counter % 2 == 0);
+                    counter++;
+                    std::atomic_signal_fence(std::memory_order::seq_cst); // force the compiler to generate two separate increment instructions
+                    counter++;
+                }
+            }
+        };
+
+        Stopwatch watch;
+        for (int i = 0; i < thrs; i++)
+            threads.emplace_back(reader);
+
+        for (auto & thread : threads)
+            thread.join();
+
+        ASSERT_EQ(counter, requests);
+
+        double ns = watch.elapsedNanoseconds();
+        std::cout << "thrs = " << thrs << ":\t" << ns / requests << " ns\t" << requests * 1e9 / ns << " rps" << std::endl;
+    }
+}
+
+TEST(Threading, SharedMutexSmokeCancellableEnabled) { TestSharedMutex<DB::CancellableSharedMutex, DB::Cancellable>(); }
+TEST(Threading, SharedMutexSmokeCancellableDisabled) { TestSharedMutex<DB::CancellableSharedMutex, NoCancel>(); }
+TEST(Threading, SharedMutexSmokeFast) { TestSharedMutex<DB::FastSharedMutex, NoCancel>(); }
+TEST(Threading, SharedMutexSmokeStd) { TestSharedMutex<std::shared_mutex, NoCancel>(); }
+
+TEST(Threading, PerfTestSharedMutexReadersOnlyCancellableEnabled) { PerfTestSharedMutexReadersOnly<DB::CancellableSharedMutex, DB::Cancellable>(); }
+TEST(Threading, PerfTestSharedMutexReadersOnlyCancellableDisabled) { PerfTestSharedMutexReadersOnly<DB::CancellableSharedMutex, NoCancel>(); }
+TEST(Threading, PerfTestSharedMutexReadersOnlyFast) { PerfTestSharedMutexReadersOnly<DB::FastSharedMutex, NoCancel>(); }
+TEST(Threading, PerfTestSharedMutexReadersOnlyStd) { PerfTestSharedMutexReadersOnly<std::shared_mutex, NoCancel>(); }
+
+TEST(Threading, PerfTestSharedMutexWritersOnlyCancellableEnabled) { PerfTestSharedMutexWritersOnly<DB::CancellableSharedMutex, DB::Cancellable>(); }
+TEST(Threading, PerfTestSharedMutexWritersOnlyCancellableDisabled) { PerfTestSharedMutexWritersOnly<DB::CancellableSharedMutex, NoCancel>(); }
+TEST(Threading, PerfTestSharedMutexWritersOnlyFast) { PerfTestSharedMutexWritersOnly<DB::FastSharedMutex, NoCancel>(); }
+TEST(Threading, PerfTestSharedMutexWritersOnlyStd) { PerfTestSharedMutexWritersOnly<std::shared_mutex, NoCancel>(); }
+
+TEST(Threading, PerfTestSharedMutexRWCancellableEnabled) { PerfTestSharedMutexRW<DB::CancellableSharedMutex, DB::Cancellable>(); }
+TEST(Threading, PerfTestSharedMutexRWCancellableDisabled) { PerfTestSharedMutexRW<DB::CancellableSharedMutex, NoCancel>(); }
+TEST(Threading, PerfTestSharedMutexRWFast) { PerfTestSharedMutexRW<DB::FastSharedMutex, NoCancel>(); }
+TEST(Threading, PerfTestSharedMutexRWStd) { PerfTestSharedMutexRW<std::shared_mutex, NoCancel>(); }
+
+#ifdef OS_LINUX /// These tests require cancellability
+
+TEST(Threading, SharedMutexCancelReaderCancellableEnabled) { TestSharedMutexCancelReader<DB::CancellableSharedMutex, DB::Cancellable>(); }
+TEST(Threading, SharedMutexCancelWriterCancellableEnabled) { TestSharedMutexCancelWriter<DB::CancellableSharedMutex, DB::Cancellable>(); }
+
+#endif
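
A minimal end-to-end sketch of how the new primitives are meant to be combined on Linux (illustrative only: `shared_state_mutex`, `worker_tid` and the error message are made up for this example; on non-Linux builds `CancelToken::signal()` is a no-op and the snippet would block):

    #include <Common/Threading.h>
    #include <Common/Exception.h>
    #include <base/getThreadId.h>

    #include <atomic>
    #include <mutex>
    #include <shared_mutex>
    #include <thread>

    namespace DB::ErrorCodes { extern const int THREAD_WAS_CANCELLED; }

    int main()
    {
        DB::CancellableSharedMutex shared_state_mutex;
        std::atomic<UInt64> worker_tid{0};

        // Hold the exclusive lock so the worker has to wait, then cancel it by thread id.
        std::unique_lock write_lock(shared_state_mutex);

        std::thread worker([&]
        {
            DB::Cancellable cancellable_scope;  // enable cancellation for this thread first
            worker_tid.store(getThreadId());    // then publish the thread id
            try
            {
                std::shared_lock read_lock(shared_state_mutex); // blocks behind the writer -> cancellable wait
                // ... read shared state ...
            }
            catch (const DB::Exception &)       // THREAD_WAS_CANCELLED, delivered at cancellable wait points
            {
                // ... cleanup on cancellation ...
            }
        });

        while (worker_tid.load() == 0)
            std::this_thread::yield();          // wait until the worker has opted in to cancellation
        DB::CancelToken::signal(worker_tid.load(), DB::ErrorCodes::THREAD_WAS_CANCELLED, "cancelled by main");

        worker.join();                          // the worker is woken by the signal, throws and exits
        return 0;
    }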