mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-12-03 13:02:00 +00:00
add fast and cancellable shared_mutex alternatives
This commit is contained in:
parent
ac162a2c49
commit
baf6297f1d
@ -643,6 +643,7 @@
|
||||
M(672, INVALID_SCHEDULER_NODE) \
|
||||
M(673, RESOURCE_ACCESS_DENIED) \
|
||||
M(674, RESOURCE_NOT_FOUND) \
|
||||
M(675, THREAD_WAS_CANCELLED) \
|
||||
\
|
||||
M(999, KEEPER_EXCEPTION) \
|
||||
M(1000, POCO_EXCEPTION) \
|
||||
|
511
src/Common/Threading.cpp
Normal file
511
src/Common/Threading.cpp
Normal file
@ -0,0 +1,511 @@
|
||||
#include <Common/Threading.h>
|
||||
#include <Common/Exception.h>
|
||||
|
||||
#ifdef OS_LINUX /// Because of 'sigqueue' functions, RT signals and futex.
|
||||
|
||||
#include <base/getThreadId.h>
|
||||
|
||||
#include <bit>
|
||||
|
||||
#include <linux/futex.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <unistd.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int THREAD_WAS_CANCELLED;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
inline long futexWait(void * address, UInt32 value)
|
||||
{
|
||||
return syscall(SYS_futex, address, FUTEX_WAIT_PRIVATE, value, nullptr, nullptr, 0);
|
||||
}
|
||||
|
||||
inline long futexWake(void * address, int count)
|
||||
{
|
||||
return syscall(SYS_futex, address, FUTEX_WAKE_PRIVATE, count, nullptr, nullptr, 0);
|
||||
}
|
||||
|
||||
// inline void waitFetch(std::atomic<UInt32> & address, UInt32 & value)
|
||||
// {
|
||||
// futexWait(&address, value);
|
||||
// value = address.load();
|
||||
// }
|
||||
|
||||
// inline void wakeOne(std::atomic<UInt32> & address)
|
||||
// {
|
||||
// futexWake(&address, 1);
|
||||
// }
|
||||
|
||||
// inline void wakeAll(std::atomic<UInt32> & address)
|
||||
// {
|
||||
// futexWake(&address, INT_MAX);
|
||||
// }
|
||||
|
||||
inline constexpr UInt32 lowerValue(UInt64 value)
|
||||
{
|
||||
return UInt32(value & 0xffffffffull);
|
||||
}
|
||||
|
||||
inline constexpr UInt32 upperValue(UInt64 value)
|
||||
{
|
||||
return UInt32(value >> 32ull);
|
||||
}
|
||||
|
||||
inline UInt32 * lowerAddress(void * address)
|
||||
{
|
||||
return reinterpret_cast<UInt32 *>(address) + (std::endian::native == std::endian::big);
|
||||
}
|
||||
|
||||
inline UInt32 * upperAddress(void * address)
|
||||
{
|
||||
return reinterpret_cast<UInt32 *>(address) + (std::endian::native == std::endian::little);
|
||||
}
|
||||
|
||||
inline void waitLowerFetch(std::atomic<UInt64> & address, UInt64 & value)
|
||||
{
|
||||
futexWait(lowerAddress(&address), lowerValue(value));
|
||||
value = address.load();
|
||||
}
|
||||
|
||||
inline bool cancellableWaitLowerFetch(std::atomic<UInt64> & address, UInt64 & value)
|
||||
{
|
||||
bool res = CancelToken::local().wait(lowerAddress(&address), lowerValue(value));
|
||||
value = address.load();
|
||||
return res;
|
||||
}
|
||||
|
||||
inline void wakeLowerOne(std::atomic<UInt64> & address)
|
||||
{
|
||||
syscall(SYS_futex, lowerAddress(&address), FUTEX_WAKE_PRIVATE, 1, nullptr, nullptr, 0);
|
||||
}
|
||||
|
||||
// inline void wakeLowerAll(std::atomic<UInt64> & address)
|
||||
// {
|
||||
// syscall(SYS_futex, lowerAddress(&address), FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0);
|
||||
// }
|
||||
|
||||
inline void waitUpperFetch(std::atomic<UInt64> & address, UInt64 & value)
|
||||
{
|
||||
futexWait(upperAddress(&address), upperValue(value));
|
||||
value = address.load();
|
||||
}
|
||||
|
||||
inline bool cancellableWaitUpperFetch(std::atomic<UInt64> & address, UInt64 & value)
|
||||
{
|
||||
bool res = CancelToken::local().wait(upperAddress(&address), upperValue(value));
|
||||
value = address.load();
|
||||
return res;
|
||||
}
|
||||
|
||||
// inline void wakeUpperOne(std::atomic<UInt64> & address)
|
||||
// {
|
||||
// syscall(SYS_futex, upperAddress(&address), FUTEX_WAKE_PRIVATE, 1, nullptr, nullptr, 0);
|
||||
// }
|
||||
|
||||
inline void wakeUpperAll(std::atomic<UInt64> & address)
|
||||
{
|
||||
syscall(SYS_futex, upperAddress(&address), FUTEX_WAKE_PRIVATE, INT_MAX, nullptr, nullptr, 0);
|
||||
}
|
||||
}
|
||||
|
||||
CancelToken::Registry::Registry()
|
||||
{
|
||||
// setupCancelSignalHandler();
|
||||
}
|
||||
|
||||
void CancelToken::Registry::insert(CancelToken * token)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
threads[token->thread_id] = token;
|
||||
}
|
||||
|
||||
void CancelToken::Registry::remove(CancelToken * token)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
threads.erase(token->thread_id);
|
||||
}
|
||||
|
||||
void CancelToken::Registry::signal(UInt64 tid)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (auto it = threads.find(tid); it != threads.end())
|
||||
it->second->signalImpl();
|
||||
}
|
||||
|
||||
void CancelToken::Registry::signal(UInt64 tid, int code, const String & message)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (auto it = threads.find(tid); it != threads.end())
|
||||
it->second->signalImpl(code, message);
|
||||
}
|
||||
|
||||
CancelToken::Registry & CancelToken::Registry::instance()
|
||||
{
|
||||
static Registry registry;
|
||||
return registry;
|
||||
}
|
||||
|
||||
CancelToken::CancelToken()
|
||||
: state(disabled)
|
||||
, thread_id(getThreadId())
|
||||
{
|
||||
Registry::instance().insert(this);
|
||||
}
|
||||
|
||||
CancelToken::~CancelToken()
|
||||
{
|
||||
Registry::instance().remove(this);
|
||||
}
|
||||
|
||||
void CancelToken::signal(UInt64 tid)
|
||||
{
|
||||
Registry::instance().signal(tid);
|
||||
}
|
||||
|
||||
void CancelToken::signal(UInt64 tid, int code, const String & message)
|
||||
{
|
||||
Registry::instance().signal(tid, code, message);
|
||||
}
|
||||
|
||||
bool CancelToken::wait(UInt32 * address, UInt32 value)
|
||||
{
|
||||
chassert((reinterpret_cast<UInt64>(address) & canceled) == 0); // An `address` must be 2-byte aligned
|
||||
if (value & signaled) // Can happen after spurious wake-up due to cancel of other thread
|
||||
{
|
||||
// static std::atomic<int> x{0};
|
||||
// if (x++ > 5)
|
||||
// sleep(3600);
|
||||
return true; // Spin-wait unless signal is handled
|
||||
}
|
||||
|
||||
UInt64 s = state.load();
|
||||
while (true)
|
||||
{
|
||||
DBG("s={}", s);
|
||||
if (s & disabled)
|
||||
{
|
||||
// Start non-cancellable wait on futex. Spurious wake-up is possible.
|
||||
futexWait(address, value);
|
||||
return true; // Disabled - true is forced
|
||||
}
|
||||
if (s & canceled)
|
||||
return false; // Has already been canceled
|
||||
if (state.compare_exchange_strong(s, reinterpret_cast<UInt64>(address)))
|
||||
break; // This futex has been "acquired" by this token
|
||||
}
|
||||
|
||||
// Start cancellable wait. Spurious wake-up is possible.
|
||||
DBG("start cancellable wait address={} value={}", static_cast<void*>(address), value);
|
||||
futexWait(address, value);
|
||||
|
||||
// "Release" futex and check for cancellation
|
||||
s = state.load();
|
||||
while (true)
|
||||
{
|
||||
DBG("finish cancellable wait, s={}", s);
|
||||
chassert((s & disabled) != disabled); // `disable()` must not be called from another thread
|
||||
if (s & canceled)
|
||||
{
|
||||
if (s == canceled)
|
||||
break; // Signaled; futex "release" has been done by the signaling thread
|
||||
else
|
||||
{
|
||||
s = state.load();
|
||||
continue; // To avoid race (may lead to futex destruction) we have to wait for signaling thread to finish
|
||||
}
|
||||
}
|
||||
if (state.compare_exchange_strong(s, 0))
|
||||
return true; // There was no cancellation; futex "released"
|
||||
}
|
||||
|
||||
// Reset signaled bit
|
||||
reinterpret_cast<std::atomic<UInt32> *>(address)->fetch_and(~signaled);
|
||||
return false;
|
||||
}
|
||||
|
||||
void CancelToken::raise()
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(signal_mutex);
|
||||
DBG("raise code={} msg={}", exception_code, exception_message);
|
||||
if (exception_code != 0)
|
||||
throw DB::Exception(
|
||||
std::exchange(exception_code, 0),
|
||||
std::exchange(exception_message, {}));
|
||||
else
|
||||
throw DB::Exception(ErrorCodes::THREAD_WAS_CANCELLED, "Thread was cancelled");
|
||||
}
|
||||
|
||||
void CancelToken::notifyOne(UInt32 * address)
|
||||
{
|
||||
futexWake(address, 1);
|
||||
}
|
||||
|
||||
void CancelToken::notifyAll(UInt32 * address)
|
||||
{
|
||||
futexWake(address, INT_MAX);
|
||||
}
|
||||
|
||||
void CancelToken::signalImpl()
|
||||
{
|
||||
signalImpl(0, {});
|
||||
}
|
||||
|
||||
std::mutex CancelToken::signal_mutex;
|
||||
|
||||
void CancelToken::signalImpl(int code, const String & message)
|
||||
{
|
||||
// Serialize all signaling threads to avoid races due to concurrent signal()/raise() calls
|
||||
std::unique_lock<std::mutex> lock(signal_mutex);
|
||||
|
||||
UInt64 s = state.load();
|
||||
while (true)
|
||||
{
|
||||
DBG("s={}", s);
|
||||
if (s & canceled)
|
||||
return; // Already cancelled - don't signal twice
|
||||
if (state.compare_exchange_strong(s, s | canceled))
|
||||
break; // It is the cancelling thread - should deliver signal if necessary
|
||||
}
|
||||
|
||||
DBG("cancel tid={} code={} msg={}", thread_id, code, message);
|
||||
exception_code = code;
|
||||
exception_message = message;
|
||||
|
||||
if ((s & disabled) == disabled)
|
||||
return; // Cancellation is disabled - just signal token for later, but don't wake
|
||||
std::atomic<UInt32> * address = reinterpret_cast<std::atomic<UInt32> *>(s & disabled);
|
||||
DBG("address={}", static_cast<void*>(address));
|
||||
if (address == nullptr)
|
||||
return; // Thread is currently not waiting on futex - wake-up not required
|
||||
|
||||
// Set signaled bit
|
||||
UInt32 value = address->load();
|
||||
while (true)
|
||||
{
|
||||
if (value & signaled) // Already signaled, just spin-wait until previous signal is handled by waiter
|
||||
value = address->load();
|
||||
else if (address->compare_exchange_strong(value, value | signaled))
|
||||
break;
|
||||
}
|
||||
|
||||
// Wake all threads waiting on `address`, one of them will be cancelled and others will get spurious wake-ups
|
||||
// Woken canceled thread will reset signaled bit
|
||||
DBG("wake");
|
||||
futexWake(address, INT_MAX);
|
||||
|
||||
// Signaling thread must remove address from state to notify canceled thread that `futexWake()` is done, thus `wake()` can return.
|
||||
// Otherwise we may have race condition: signaling thread may try to wake futex that has been already destructed.
|
||||
state.store(canceled);
|
||||
}
|
||||
|
||||
Cancellable::Cancellable()
|
||||
{
|
||||
CancelToken::local().reset();
|
||||
}
|
||||
|
||||
Cancellable::~Cancellable()
|
||||
{
|
||||
CancelToken::local().disable();
|
||||
}
|
||||
|
||||
NotCancellable::NotCancellable()
|
||||
{
|
||||
CancelToken::local().disable();
|
||||
}
|
||||
|
||||
NotCancellable::~NotCancellable()
|
||||
{
|
||||
CancelToken::local().enable();
|
||||
}
|
||||
|
||||
CancellableSharedMutex::CancellableSharedMutex()
|
||||
: state(0)
|
||||
, waiters(0)
|
||||
{}
|
||||
|
||||
void CancellableSharedMutex::lock()
|
||||
{
|
||||
UInt64 value = state.load();
|
||||
while (true)
|
||||
{
|
||||
DBG("#A r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
|
||||
if (value & writers)
|
||||
{
|
||||
waiters++;
|
||||
if (!cancellableWaitUpperFetch(state, value))
|
||||
{
|
||||
waiters--;
|
||||
CancelToken::local().raise();
|
||||
}
|
||||
else
|
||||
waiters--;
|
||||
}
|
||||
else if (state.compare_exchange_strong(value, value | writers))
|
||||
break;
|
||||
}
|
||||
|
||||
value |= writers;
|
||||
while (value & readers)
|
||||
{
|
||||
DBG("#B r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
|
||||
if (!cancellableWaitLowerFetch(state, value))
|
||||
{
|
||||
state.fetch_and(~writers);
|
||||
wakeUpperAll(state);
|
||||
CancelToken::local().raise();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CancellableSharedMutex::try_lock()
|
||||
{
|
||||
UInt64 value = state.load();
|
||||
if ((value & (readers | writers)) == 0 && state.compare_exchange_strong(value, value | writers))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void CancellableSharedMutex::unlock()
|
||||
{
|
||||
UInt64 value = state.fetch_and(~writers);
|
||||
DBG("r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
|
||||
if (waiters)
|
||||
wakeUpperAll(state);
|
||||
}
|
||||
|
||||
void CancellableSharedMutex::lock_shared()
|
||||
{
|
||||
UInt64 value = state.load();
|
||||
while (true)
|
||||
{
|
||||
DBG("r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
|
||||
if (value & writers)
|
||||
{
|
||||
waiters++;
|
||||
if (!cancellableWaitUpperFetch(state, value))
|
||||
{
|
||||
waiters--;
|
||||
CancelToken::local().raise();
|
||||
}
|
||||
else
|
||||
waiters--;
|
||||
}
|
||||
else if (state.compare_exchange_strong(value, value + 1)) // overflow is not realistic
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool CancellableSharedMutex::try_lock_shared()
|
||||
{
|
||||
UInt64 value = state.load();
|
||||
if (!(value & writers) && state.compare_exchange_strong(value, value + 1)) // overflow is not realistic
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void CancellableSharedMutex::unlock_shared()
|
||||
{
|
||||
UInt64 value = state.fetch_sub(1) - 1;
|
||||
DBG("r={} w={} rs={} ws={}", value & readers, (value & writers) != 0, (value & readers_signaled) != 0, (value & writers_signaled) != 0);
|
||||
if ((value & (writers | readers)) == writers) // If writer is waiting and no more readers
|
||||
wakeLowerOne(state); // Wake writer
|
||||
}
|
||||
|
||||
FastSharedMutex::FastSharedMutex()
|
||||
: state(0)
|
||||
, waiters(0)
|
||||
{}
|
||||
|
||||
void FastSharedMutex::lock()
|
||||
{
|
||||
UInt64 value = state.load();
|
||||
while (true)
|
||||
{
|
||||
if (value & writers)
|
||||
{
|
||||
waiters++;
|
||||
waitUpperFetch(state, value);
|
||||
waiters--;
|
||||
}
|
||||
else if (state.compare_exchange_strong(value, value | writers))
|
||||
break;
|
||||
}
|
||||
|
||||
value |= writers;
|
||||
while (value & readers)
|
||||
waitLowerFetch(state, value);
|
||||
}
|
||||
|
||||
bool FastSharedMutex::try_lock()
|
||||
{
|
||||
UInt64 value = 0;
|
||||
if (state.compare_exchange_strong(value, writers))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void FastSharedMutex::unlock()
|
||||
{
|
||||
state.store(0);
|
||||
if (waiters)
|
||||
wakeUpperAll(state);
|
||||
}
|
||||
|
||||
void FastSharedMutex::lock_shared()
|
||||
{
|
||||
UInt64 value = state.load();
|
||||
while (true)
|
||||
{
|
||||
if (value & writers)
|
||||
{
|
||||
waiters++;
|
||||
waitUpperFetch(state, value);
|
||||
waiters--;
|
||||
}
|
||||
else if (state.compare_exchange_strong(value, value + 1))
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool FastSharedMutex::try_lock_shared()
|
||||
{
|
||||
UInt64 value = state.load();
|
||||
if (!(value & writers) && state.compare_exchange_strong(value, value + 1))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void FastSharedMutex::unlock_shared()
|
||||
{
|
||||
UInt64 value = state.fetch_sub(1) - 1;
|
||||
if (value == writers)
|
||||
wakeLowerOne(state); // Wake writer
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int THREAD_WAS_CANCELLED;
|
||||
}
|
||||
|
||||
void CancelToken::raise()
|
||||
{
|
||||
throw DB::Exception(ErrorCodes::THREAD_WAS_CANCELLED, "Thread was cancelled");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
282
src/Common/Threading.h
Normal file
282
src/Common/Threading.h
Normal file
@ -0,0 +1,282 @@
|
||||
#pragma once
|
||||
|
||||
#include <base/types.h>
|
||||
#include <base/defines.h>
|
||||
|
||||
#ifdef OS_LINUX /// Because of futex
|
||||
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
// TODO(serxa): for debug only, remove it
|
||||
#if 0
|
||||
#include <iostream>
|
||||
#include <base/getThreadId.h>
|
||||
#define DBG(...) std::cout << fmt::format("\033[01;3{}m[{}] {} {} {}\033[00m {}:{}\n", 1 + getThreadId() % 8, getThreadId(), reinterpret_cast<void*>(this), fmt::format(__VA_ARGS__), __PRETTY_FUNCTION__, __FILE__, __LINE__)
|
||||
#else
|
||||
#include <base/defines.h>
|
||||
#define DBG(...) UNUSED(__VA_ARGS__)
|
||||
#endif
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
// Scoped object, enabling thread cancellation (cannot be nested)
|
||||
struct Cancellable
|
||||
{
|
||||
Cancellable();
|
||||
~Cancellable();
|
||||
};
|
||||
|
||||
// Scoped object, disabling thread cancellation (cannot be nested; must be inside `Cancellable` region)
|
||||
struct NotCancellable
|
||||
{
|
||||
NotCancellable();
|
||||
~NotCancellable();
|
||||
};
|
||||
|
||||
// Responsible for synchronization needed to deliver thread cancellation signal.
|
||||
// Basic building block for cancallable synchronization primitives.
|
||||
// Allows to perform cancellable wait on memory addresses (think futex)
|
||||
class CancelToken
|
||||
{
|
||||
public:
|
||||
CancelToken();
|
||||
CancelToken(const CancelToken &) = delete;
|
||||
CancelToken(CancelToken &&) = delete;
|
||||
CancelToken & operator=(const CancelToken &) = delete;
|
||||
~CancelToken();
|
||||
|
||||
// Returns token for the current thread
|
||||
static CancelToken & local()
|
||||
{
|
||||
static thread_local CancelToken token;
|
||||
return token;
|
||||
}
|
||||
|
||||
// Cancellable wait on memory address (futex word).
|
||||
// Thread will do atomic compare-and-sleep `*address == value`. Waiting will continue until `notify_one()`
|
||||
// or `notify_all()` will be called with the same `address` or calling thread will be canceled using `signal()`.
|
||||
// Note that spurious wake-ups are also possible due to cancellation of other waiters on the same `address`.
|
||||
// WARNING: `address` must be 2-byte aligned and `value` highest bit must be zero.
|
||||
// Return value:
|
||||
// true - woken by either notify or spurious wakeup;
|
||||
// false - iff cancelation signal has been received.
|
||||
// Implementation details:
|
||||
// It registers `address` inside token's `state` to allow other threads to wake this thread and deliver cancellation signal.
|
||||
// Highest bit of `*address` is used for guarantied delivery of the signal, but is guaranteed to be zero on return due to cancellation.
|
||||
// Intented to be called only by thread associated with this token.
|
||||
bool wait(UInt32 * address, UInt32 value);
|
||||
|
||||
// Throws `DB::Exception` received from `signal()`. Call it if `wait()` returned false.
|
||||
// Intented to be called only by thread associated with this token.
|
||||
[[noreturn]] void raise();
|
||||
|
||||
// Regular wake by address (futex word). It does not interact with token in any way. We have it here to complement `wait()`.
|
||||
// Can be called from any thread.
|
||||
static void notifyOne(UInt32 * address);
|
||||
static void notifyAll(UInt32 * address);
|
||||
|
||||
// Send cancel signal to thread with specified `tid`.
|
||||
// If thread was waiting using `wait()` it will be woken up (unless cancellation is disabled).
|
||||
// Can be called from any thread.
|
||||
static void signal(UInt64 tid);
|
||||
static void signal(UInt64 tid, int code, const String & message);
|
||||
|
||||
// Flag used to deliver cancellation into memory address to wake a thread.
|
||||
// Note that most significat bit at `addresses` to be used with `wait()` is reserved.
|
||||
static constexpr UInt32 signaled = 1u << 31u;
|
||||
|
||||
private:
|
||||
friend struct Cancellable;
|
||||
friend struct NotCancellable;
|
||||
|
||||
// Restores initial state for token to be reused. See `Cancellable` struct.
|
||||
// Intented to be called only by thread associated with this token.
|
||||
void reset()
|
||||
{
|
||||
state.store(0);
|
||||
}
|
||||
|
||||
// Enable thread cancellation. See `NotCancellable` struct.
|
||||
// Intented to be called only by thread associated with this token.
|
||||
void enable()
|
||||
{
|
||||
chassert((state.load() & disabled) == disabled);
|
||||
state.fetch_and(~disabled);
|
||||
}
|
||||
|
||||
// Disable thread cancellation. See `NotCancellable` struct.
|
||||
// Intented to be called only by thread associated with this token.
|
||||
void disable()
|
||||
{
|
||||
chassert((state.load() & disabled) == 0);
|
||||
state.fetch_or(disabled);
|
||||
}
|
||||
|
||||
// Singleton. Maps thread IDs to tokens.
|
||||
struct Registry;
|
||||
friend struct Registry;
|
||||
struct Registry
|
||||
{
|
||||
Registry();
|
||||
|
||||
std::mutex mutex;
|
||||
std::unordered_map<UInt64, CancelToken*> threads; // By thread ID
|
||||
|
||||
void insert(CancelToken * token);
|
||||
void remove(CancelToken * token);
|
||||
void signal(UInt64 tid);
|
||||
void signal(UInt64 tid, int code, const String & message);
|
||||
|
||||
static Registry & instance();
|
||||
};
|
||||
|
||||
// Cancels this token and wakes thread if necessary.
|
||||
// Can be called from any thread.
|
||||
void signalImpl();
|
||||
void signalImpl(int code, const String & message);
|
||||
|
||||
// Lower bit: cancel signal received flag
|
||||
static constexpr UInt64 canceled = 1;
|
||||
|
||||
// Upper bits - possible values:
|
||||
// 1) all zeros: token is enabed, i.e. wait() call can return false, thread is not waiting on any address;
|
||||
// 2) all ones: token is disabled, i.e. wait() call cannot be cancelled;
|
||||
// 3) specific `address`: token is enabled and thread is currently waiting on this `address`.
|
||||
static constexpr UInt64 disabled = ~canceled;
|
||||
static_assert(sizeof(UInt32 *) == sizeof(UInt64)); // State must be able to hold an address
|
||||
|
||||
// All signal handling logic should be globally serialized using this mutex
|
||||
static std::mutex signal_mutex;
|
||||
|
||||
// Cancellation state
|
||||
alignas(64) std::atomic<UInt64> state;
|
||||
[[maybe_unused]] char padding[64 - sizeof(state)];
|
||||
|
||||
// Cancellation exception
|
||||
int exception_code;
|
||||
String exception_message;
|
||||
|
||||
// Token is permanently attached to a single thread. There is one-to-one mapping between threads and tokens.
|
||||
const UInt64 thread_id;
|
||||
};
|
||||
|
||||
class CancellableSharedMutex
|
||||
{
|
||||
public:
|
||||
CancellableSharedMutex();
|
||||
~CancellableSharedMutex() = default;
|
||||
CancellableSharedMutex(const CancellableSharedMutex &) = delete;
|
||||
CancellableSharedMutex & operator=(const CancellableSharedMutex &) = delete;
|
||||
|
||||
// Exclusive ownership
|
||||
void lock();
|
||||
bool try_lock();
|
||||
void unlock();
|
||||
|
||||
// Shared ownership
|
||||
void lock_shared();
|
||||
bool try_lock_shared();
|
||||
void unlock_shared();
|
||||
|
||||
private:
|
||||
// State 64-bits layout:
|
||||
// 1b - 31b - 1b - 31b
|
||||
// signaled - writers - signaled - readers
|
||||
// 63------------------------------------0
|
||||
// Two 32-bit words are used for cancellable waiting, so each has its own separate signaled bit
|
||||
static constexpr UInt64 readers = (1ull << 32ull) - 1ull - CancelToken::signaled;
|
||||
static constexpr UInt64 readers_signaled = CancelToken::signaled;
|
||||
static constexpr UInt64 writers = readers << 32ull;
|
||||
static constexpr UInt64 writers_signaled = readers_signaled << 32ull;
|
||||
|
||||
alignas(64) std::atomic<UInt64> state;
|
||||
std::atomic<UInt32> waiters;
|
||||
};
|
||||
|
||||
class FastSharedMutex
|
||||
{
|
||||
public:
|
||||
FastSharedMutex();
|
||||
~FastSharedMutex() = default;
|
||||
FastSharedMutex(const FastSharedMutex &) = delete;
|
||||
FastSharedMutex & operator=(const FastSharedMutex &) = delete;
|
||||
|
||||
// Exclusive ownership
|
||||
void lock();
|
||||
bool try_lock();
|
||||
void unlock();
|
||||
|
||||
// Shared ownership
|
||||
void lock_shared();
|
||||
bool try_lock_shared();
|
||||
void unlock_shared();
|
||||
|
||||
private:
|
||||
static constexpr UInt64 readers = (1ull << 32ull) - 1ull; // Lower 32 bits of state
|
||||
static constexpr UInt64 writers = ~readers; // Upper 32 bits of state
|
||||
|
||||
alignas(64) std::atomic<UInt64> state;
|
||||
std::atomic<UInt32> waiters;
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#include <shared_mutex>
|
||||
|
||||
// WARNING: We support cancellable synchronization primitives only on linux for now
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int THREAD_WAS_CANCELLED;
|
||||
}
|
||||
|
||||
struct Cancellable
|
||||
{
|
||||
Cancellable() = default;
|
||||
~Cancellable() = default;
|
||||
};
|
||||
|
||||
struct NotCancellable
|
||||
{
|
||||
NotCancellable() = default;
|
||||
~NotCancellable() = default;
|
||||
};
|
||||
|
||||
class CancelToken
|
||||
{
|
||||
public:
|
||||
CancelToken() = default;
|
||||
CancelToken(const CancelToken &) = delete;
|
||||
CancelToken(CancelToken &&) = delete;
|
||||
CancelToken & operator=(const CancelToken &) = delete;
|
||||
~CancelToken() = default;
|
||||
|
||||
static CancelToken & local()
|
||||
{
|
||||
static CancelToken token;
|
||||
return token;
|
||||
}
|
||||
|
||||
bool wait(UInt32 *, UInt32) { return true; }
|
||||
[[noreturn]] void raise();
|
||||
static void notifyOne(UInt32 *) {}
|
||||
static void notifyAll(UInt32 *) {}
|
||||
static void signal(UInt64) {}
|
||||
static void signal(UInt64, int, const String &) {}
|
||||
};
|
||||
|
||||
using CancellableSharedMutex = std::shared_mutex;
|
||||
using FastSharedMutex = std::shared_mutex;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
369
src/Common/tests/gtest_threading.cpp
Normal file
369
src/Common/tests/gtest_threading.cpp
Normal file
@ -0,0 +1,369 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <thread>
|
||||
#include <condition_variable>
|
||||
#include <shared_mutex>
|
||||
#include <barrier>
|
||||
#include <atomic>
|
||||
|
||||
#include "Common/Exception.h"
|
||||
#include <Common/Threading.h>
|
||||
#include <Common/Stopwatch.h>
|
||||
|
||||
#include <base/demangle.h>
|
||||
#include <base/getThreadId.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int THREAD_WAS_CANCELLED;
|
||||
}
|
||||
}
|
||||
|
||||
struct NoCancel {};
|
||||
|
||||
// for all PerfTests
|
||||
static constexpr int requests = 512 * 1024;
|
||||
static constexpr int max_threads = 16;
|
||||
|
||||
template <class T, class Status = NoCancel>
|
||||
void TestSharedMutex()
|
||||
{
|
||||
// Test multiple readers can acquire lock
|
||||
for (int readers = 1; readers <= 128; readers *= 2)
|
||||
{
|
||||
T sm;
|
||||
std::atomic<int> test(0);
|
||||
std::barrier sync(readers + 1);
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(readers);
|
||||
auto reader = [&]
|
||||
{
|
||||
[[maybe_unused]] Status status;
|
||||
std::shared_lock lock(sm);
|
||||
test++;
|
||||
sync.arrive_and_wait();
|
||||
};
|
||||
|
||||
for (int i = 0; i < readers; i++)
|
||||
threads.emplace_back(reader);
|
||||
|
||||
{ // writer
|
||||
[[maybe_unused]] Status status;
|
||||
sync.arrive_and_wait(); // wait for all reader to acquire lock to avoid blocking them
|
||||
std::unique_lock lock(sm);
|
||||
test++;
|
||||
}
|
||||
|
||||
for (auto & thread : threads)
|
||||
thread.join();
|
||||
|
||||
ASSERT_EQ(test, readers + 1);
|
||||
}
|
||||
|
||||
// Test multiple writers cannot acquire lock simultaneously
|
||||
for (int writers = 1; writers <= 128; writers *= 2)
|
||||
{
|
||||
T sm;
|
||||
int test = 0;
|
||||
std::barrier sync(writers);
|
||||
std::vector<std::thread> threads;
|
||||
|
||||
threads.reserve(writers);
|
||||
auto writer = [&]
|
||||
{
|
||||
[[maybe_unused]] Status status;
|
||||
sync.arrive_and_wait();
|
||||
std::unique_lock lock(sm);
|
||||
test++;
|
||||
};
|
||||
|
||||
for (int i = 0; i < writers; i++)
|
||||
threads.emplace_back(writer);
|
||||
|
||||
for (auto & thread : threads)
|
||||
thread.join();
|
||||
|
||||
ASSERT_EQ(test, writers);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class Status = NoCancel>
|
||||
void TestSharedMutexCancelReader()
|
||||
{
|
||||
constexpr int readers = 8;
|
||||
constexpr int tasks_per_reader = 32;
|
||||
|
||||
T sm;
|
||||
std::atomic<int> successes(0);
|
||||
std::atomic<int> cancels(0);
|
||||
std::barrier sync(readers + 1);
|
||||
std::barrier cancel_sync(readers / 2 + 1);
|
||||
std::vector<std::thread> threads;
|
||||
|
||||
std::mutex m;
|
||||
std::vector<UInt64> tids_to_cancel;
|
||||
|
||||
threads.reserve(readers);
|
||||
auto reader = [&] (int reader_id)
|
||||
{
|
||||
if (reader_id % 2 == 0)
|
||||
{
|
||||
std::unique_lock lock(m);
|
||||
tids_to_cancel.emplace_back(getThreadId());
|
||||
}
|
||||
for (int task = 0; task < tasks_per_reader; task++) {
|
||||
try
|
||||
{
|
||||
[[maybe_unused]] Status status;
|
||||
sync.arrive_and_wait(); // (A) sync with writer
|
||||
sync.arrive_and_wait(); // (B) wait for writer to acquire unique_lock
|
||||
std::shared_lock lock(sm);
|
||||
successes++;
|
||||
}
|
||||
catch(DB::Exception & e)
|
||||
{
|
||||
ASSERT_EQ(e.code(), DB::ErrorCodes::THREAD_WAS_CANCELLED);
|
||||
ASSERT_EQ(e.message(), "test");
|
||||
cancels++;
|
||||
cancel_sync.arrive_and_wait(); // (C) sync with writer
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int reader_id = 0; reader_id < readers; reader_id++)
|
||||
threads.emplace_back(reader, reader_id);
|
||||
|
||||
{ // writer
|
||||
[[maybe_unused]] Status status;
|
||||
for (int task = 0; task < tasks_per_reader; task++) {
|
||||
sync.arrive_and_wait(); // (A) wait for readers to finish previous task
|
||||
ASSERT_EQ(cancels + successes, task * readers);
|
||||
ASSERT_EQ(cancels, task * readers / 2);
|
||||
ASSERT_EQ(successes, task * readers / 2);
|
||||
std::unique_lock lock(sm);
|
||||
sync.arrive_and_wait(); // (B) sync with readers
|
||||
//std::unique_lock lock(m); // not needed, already synced using barrier
|
||||
for (UInt64 tid : tids_to_cancel)
|
||||
DB::CancelToken::signal(tid, DB::ErrorCodes::THREAD_WAS_CANCELLED, "test");
|
||||
|
||||
// This sync is crutial. It is needed to hold `lock` long enough.
|
||||
// It guarantees that every cancelled thread will find `sm` blocked by writer, and thus will begin to wait.
|
||||
// Wait() call is required for cancellation. Otherwise, fastpath acquire w/o wait will not generate exception.
|
||||
// And this is the desired behaviour.
|
||||
cancel_sync.arrive_and_wait(); // (C) wait for cancellation to finish, before unlock.
|
||||
}
|
||||
}
|
||||
|
||||
for (auto & thread : threads)
|
||||
thread.join();
|
||||
|
||||
ASSERT_EQ(successes, tasks_per_reader * readers / 2);
|
||||
ASSERT_EQ(cancels, tasks_per_reader * readers / 2);
|
||||
}
|
||||
|
||||
template <class T, class Status = NoCancel>
|
||||
void TestSharedMutexCancelWriter()
|
||||
{
|
||||
constexpr int writers = 8;
|
||||
constexpr int tasks_per_writer = 32;
|
||||
|
||||
T sm;
|
||||
std::atomic<int> successes(0);
|
||||
std::atomic<int> cancels(0);
|
||||
std::barrier sync(writers);
|
||||
std::vector<std::thread> threads;
|
||||
|
||||
std::mutex m;
|
||||
std::vector<UInt64> all_tids;
|
||||
|
||||
threads.reserve(writers);
|
||||
auto writer = [&]
|
||||
{
|
||||
{
|
||||
std::unique_lock lock(m);
|
||||
all_tids.emplace_back(getThreadId());
|
||||
}
|
||||
for (int task = 0; task < tasks_per_writer; task++) {
|
||||
try
|
||||
{
|
||||
[[maybe_unused]] Status status;
|
||||
sync.arrive_and_wait(); // (A) sync all threads before race to acquire the lock
|
||||
std::unique_lock lock(sm);
|
||||
successes++;
|
||||
// Thread that managed to acquire the lock cancels all other waiting writers
|
||||
//std::unique_lock lock(m); // not needed, already synced using barrier
|
||||
for (UInt64 tid : all_tids)
|
||||
{
|
||||
if (tid != getThreadId())
|
||||
DB::CancelToken::signal(tid, DB::ErrorCodes::THREAD_WAS_CANCELLED, "test");
|
||||
}
|
||||
|
||||
// This sync is crutial. It is needed to hold `lock` long enough.
|
||||
// It guarantees that every cancelled thread will find `sm` blocked, and thus will begin to wait.
|
||||
// Wait() call is required for cancellation. Otherwise, fastpath acquire w/o wait will not generate exception.
|
||||
// And this is the desired behaviour.
|
||||
sync.arrive_and_wait(); // (B) wait for cancellation to finish, before unlock.
|
||||
}
|
||||
catch(DB::Exception & e)
|
||||
{
|
||||
ASSERT_EQ(e.code(), DB::ErrorCodes::THREAD_WAS_CANCELLED);
|
||||
ASSERT_EQ(e.message(), "test");
|
||||
cancels++;
|
||||
sync.arrive_and_wait(); // (B) sync with race winner
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int writer_id = 0; writer_id < writers; writer_id++)
|
||||
threads.emplace_back(writer);
|
||||
|
||||
for (auto & thread : threads)
|
||||
thread.join();
|
||||
|
||||
ASSERT_EQ(successes, tasks_per_writer);
|
||||
ASSERT_EQ(cancels, tasks_per_writer * (writers - 1));
|
||||
}
|
||||
|
||||
template <class T, class Status = NoCancel>
|
||||
void PerfTestSharedMutexReadersOnly()
|
||||
{
|
||||
std::cout << "*** " << demangle(typeid(T).name()) << "/" << demangle(typeid(Status).name()) << " ***" << std::endl;
|
||||
|
||||
for (int thrs = 1; thrs <= max_threads; thrs *= 2)
|
||||
{
|
||||
T sm;
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(thrs);
|
||||
auto reader = [&]
|
||||
{
|
||||
[[maybe_unused]] Status status;
|
||||
for (int request = requests / thrs; request; request--)
|
||||
{
|
||||
std::shared_lock lock(sm);
|
||||
}
|
||||
};
|
||||
|
||||
Stopwatch watch;
|
||||
for (int i = 0; i < thrs; i++)
|
||||
threads.emplace_back(reader);
|
||||
|
||||
for (auto & thread : threads)
|
||||
thread.join();
|
||||
|
||||
double ns = watch.elapsedNanoseconds();
|
||||
std::cout << "thrs = " << thrs << ":\t" << ns / requests << " ns\t" << requests * 1e9 / ns << " rps" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class Status = NoCancel>
|
||||
void PerfTestSharedMutexWritersOnly()
|
||||
{
|
||||
std::cout << "*** " << demangle(typeid(T).name()) << "/" << demangle(typeid(Status).name()) << " ***" << std::endl;
|
||||
|
||||
for (int thrs = 1; thrs <= max_threads; thrs *= 2)
|
||||
{
|
||||
int counter = 0;
|
||||
T sm;
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(thrs);
|
||||
auto writer = [&]
|
||||
{
|
||||
[[maybe_unused]] Status status;
|
||||
for (int request = requests / thrs; request; request--)
|
||||
{
|
||||
std::unique_lock lock(sm);
|
||||
ASSERT_TRUE(counter % 2 == 0);
|
||||
counter++;
|
||||
std::atomic_signal_fence(std::memory_order::seq_cst); // force complier to generate two separate increment instructions
|
||||
counter++;
|
||||
}
|
||||
};
|
||||
|
||||
Stopwatch watch;
|
||||
for (int i = 0; i < thrs; i++)
|
||||
threads.emplace_back(writer);
|
||||
|
||||
for (auto & thread : threads)
|
||||
thread.join();
|
||||
|
||||
ASSERT_EQ(counter, requests * 2);
|
||||
|
||||
double ns = watch.elapsedNanoseconds();
|
||||
std::cout << "thrs = " << thrs << ":\t" << ns / requests << " ns\t" << requests * 1e9 / ns << " rps" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T, class Status = NoCancel>
|
||||
void PerfTestSharedMutexRW()
|
||||
{
|
||||
std::cout << "*** " << demangle(typeid(T).name()) << "/" << demangle(typeid(Status).name()) << " ***" << std::endl;
|
||||
|
||||
for (int thrs = 1; thrs <= max_threads; thrs *= 2)
|
||||
{
|
||||
int counter = 0;
|
||||
T sm;
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(thrs);
|
||||
auto reader = [&]
|
||||
{
|
||||
[[maybe_unused]] Status status;
|
||||
for (int request = requests / thrs / 2; request; request--)
|
||||
{
|
||||
{
|
||||
std::shared_lock lock(sm);
|
||||
ASSERT_TRUE(counter % 2 == 0);
|
||||
}
|
||||
{
|
||||
std::unique_lock lock(sm);
|
||||
ASSERT_TRUE(counter % 2 == 0);
|
||||
counter++;
|
||||
std::atomic_signal_fence(std::memory_order::seq_cst); // force complier to generate two separate increment instructions
|
||||
counter++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Stopwatch watch;
|
||||
for (int i = 0; i < thrs; i++)
|
||||
threads.emplace_back(reader);
|
||||
|
||||
for (auto & thread : threads)
|
||||
thread.join();
|
||||
|
||||
ASSERT_EQ(counter, requests);
|
||||
|
||||
double ns = watch.elapsedNanoseconds();
|
||||
std::cout << "thrs = " << thrs << ":\t" << ns / requests << " ns\t" << requests * 1e9 / ns << " rps" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Threading, SharedMutexSmokeCancellableEnabled) { TestSharedMutex<DB::CancellableSharedMutex, DB::Cancellable>(); }
|
||||
TEST(Threading, SharedMutexSmokeCancellableDisabled) { TestSharedMutex<DB::CancellableSharedMutex>(); }
|
||||
TEST(Threading, SharedMutexSmokeFast) { TestSharedMutex<DB::FastSharedMutex>(); }
|
||||
TEST(Threading, SharedMutexSmokeStd) { TestSharedMutex<std::shared_mutex>(); }
|
||||
|
||||
TEST(Threading, PerfTestSharedMutexReadersOnlyCancellableEnabled) { PerfTestSharedMutexReadersOnly<DB::CancellableSharedMutex, DB::Cancellable>(); }
|
||||
TEST(Threading, PerfTestSharedMutexReadersOnlyCancellableDisabled) { PerfTestSharedMutexReadersOnly<DB::CancellableSharedMutex>(); }
|
||||
TEST(Threading, PerfTestSharedMutexReadersOnlyFast) { PerfTestSharedMutexReadersOnly<DB::FastSharedMutex>(); }
|
||||
TEST(Threading, PerfTestSharedMutexReadersOnlyStd) { PerfTestSharedMutexReadersOnly<std::shared_mutex>(); }
|
||||
|
||||
TEST(Threading, PerfTestSharedMutexWritersOnlyCancellableEnabled) { PerfTestSharedMutexWritersOnly<DB::CancellableSharedMutex, DB::Cancellable>(); }
|
||||
TEST(Threading, PerfTestSharedMutexWritersOnlyCancellableDisabled) { PerfTestSharedMutexWritersOnly<DB::CancellableSharedMutex>(); }
|
||||
TEST(Threading, PerfTestSharedMutexWritersOnlyFast) { PerfTestSharedMutexWritersOnly<DB::FastSharedMutex>(); }
|
||||
TEST(Threading, PerfTestSharedMutexWritersOnlyStd) { PerfTestSharedMutexWritersOnly<std::shared_mutex>(); }
|
||||
|
||||
TEST(Threading, PerfTestSharedMutexRWCancellableEnabled) { PerfTestSharedMutexRW<DB::CancellableSharedMutex, DB::Cancellable>(); }
|
||||
TEST(Threading, PerfTestSharedMutexRWCancellableDisabled) { PerfTestSharedMutexRW<DB::CancellableSharedMutex>(); }
|
||||
TEST(Threading, PerfTestSharedMutexRWFast) { PerfTestSharedMutexRW<DB::FastSharedMutex>(); }
|
||||
TEST(Threading, PerfTestSharedMutexRWStd) { PerfTestSharedMutexRW<std::shared_mutex>(); }
|
||||
|
||||
#ifdef OS_LINUX /// These tests require cancellability
|
||||
|
||||
TEST(Threading, SharedMutexCancelReaderCancellableEnabled) { TestSharedMutexCancelReader<DB::CancellableSharedMutex, DB::Cancellable>(); }
|
||||
TEST(Threading, SharedMutexCancelWriterCancellableEnabled) { TestSharedMutexCancelWriter<DB::CancellableSharedMutex, DB::Cancellable>(); }
|
||||
|
||||
#endif
|
Loading…
Reference in New Issue
Block a user