Merge pull request #33098 from bigo-sg/lrucache

Add alternative LRUCache version
Kseniia Sumarokova 2021-12-30 23:10:28 +03:00 committed by GitHub
commit a703bcb0c5
5 changed files with 753 additions and 1 deletion

@@ -64,6 +64,18 @@ public:
setImpl(key, mapped, lock);
}
void remove(const Key & key)
{
std::lock_guard lock(mutex);
auto it = cells.find(key);
if (it == cells.end())
return;
auto & cell = it->second;
current_size -= cell.size;
queue.erase(cell.queue_iterator);
cells.erase(it);
}
/// If the value for the key is in the cache, returns it. If it is not, calls load_func() to
/// produce it, saves the result in the cache and returns it.
/// Only one of several concurrent threads calling getOrSet() will call load_func(),

@@ -0,0 +1,374 @@
#pragma once
#include <atomic>
#include <chrono>
#include <list>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <base/logger_useful.h>
namespace DB
{
template <typename T>
struct TrivialLRUResourceCacheWeightFunction
{
size_t operator()(const T &) const { return 1; }
};
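// Illustrative sketch (not part of the original header): a caller can supply a custom weight
// function so that eviction accounts for the real size of the resource. `PooledBuffer` and its
// `capacity()` method are hypothetical.
//
//     struct PooledBufferWeight
//     {
//         size_t operator()(const PooledBuffer & buf) const { return buf.capacity(); }
//     };
//     using BufferCache = LRUResourceCache<std::string, PooledBuffer, PooledBufferWeight>;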
/*
* A resource cache with a key index. Unlike a normal resource pool, there is only one instance per key.
* The cache has a maximum weight capacity and a limit on the number of keys. When either limit is
* exceeded, keys are evicted according to the LRU policy.
*
* acquire and release must be used in pairs.
*/
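// Illustrative usage sketch (not part of the original header); `Connection` and its `ping()`
// method are hypothetical. getOrSet() returns a holder that keeps the entry referenced; the
// reference is released automatically when the holder is destroyed.
//
//     LRUResourceCache<std::string, Connection> cache(/* max_weight */ 1024, /* max_element_size */ 16);
//     auto holder = cache.getOrSet("node-1", [] { return std::make_shared<Connection>("node-1"); });
//     if (holder)    /// nullptr means the cache could not make room for the entry
//         holder->value().ping();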
template <
typename TKey,
typename TMapped,
typename WeightFunction = TrivialLRUResourceCacheWeightFunction<TMapped>,
typename HashFunction = std::hash<TKey>>
class LRUResourceCache
{
public:
using Key = TKey;
using Mapped = TMapped;
using MappedPtr = std::shared_ptr<Mapped>;
class MappedHolder
{
public:
~MappedHolder() { cache->release(key); }
Mapped & value() { return *val.get(); }
MappedHolder(LRUResourceCache * cache_, const Key & key_, MappedPtr value_) : cache(cache_), key(key_), val(value_) { }
protected:
LRUResourceCache * cache;
Key key;
MappedPtr val;
};
using MappedHolderPtr = std::unique_ptr<MappedHolder>;
// use get() or getOrSet() to access the elements
MappedHolderPtr get(const Key & key)
{
auto mapped_ptr = getImpl(key);
if (!mapped_ptr)
return nullptr;
return std::make_unique<MappedHolder>(this, key, mapped_ptr);
}
template <typename LoadFunc>
MappedHolderPtr getOrSet(const Key & key, LoadFunc && load_func)
{
auto mapped_ptr = getImpl(key, load_func);
if (!mapped_ptr)
return nullptr;
return std::make_unique<MappedHolder>(this, key, mapped_ptr);
}
// If the key's reference_count is 0, remove it immediately; otherwise, mark it as expired and remove it in release().
void tryRemove(const Key & key)
{
std::lock_guard lock(mutex);
auto it = cells.find(key);
if (it == cells.end())
return;
auto & cell = it->second;
if (cell.reference_count == 0)
{
queue.erase(cell.queue_iterator);
current_weight -= cell.weight;
cells.erase(it);
}
else
cell.expired = true;
}
LRUResourceCache(size_t max_weight_, size_t max_element_size_ = 0) : max_weight(max_weight_), max_element_size(max_element_size_) { }
~LRUResourceCache() = default;
size_t weight()
{
std::lock_guard lock(mutex);
return current_weight;
}
size_t size()
{
std::lock_guard lock(mutex);
return cells.size();
}
void getStats(size_t & out_hits, size_t & out_misses, size_t & out_evict_count) const
{
out_hits = hits;
out_misses = misses;
out_evict_count = evict_count;
}
private:
mutable std::mutex mutex;
using LRUQueue = std::list<Key>;
using LRUQueueIterator = typename LRUQueue::iterator;
struct Cell
{
MappedPtr value;
size_t weight = 0;
LRUQueueIterator queue_iterator;
size_t reference_count = 0;
bool expired = false;
};
using Cells = std::unordered_map<Key, Cell, HashFunction>;
Cells cells;
LRUQueue queue;
size_t current_weight = 0;
size_t max_weight = 0;
size_t max_element_size = 0;
/// Represents a pending insertion attempt.
struct InsertToken
{
explicit InsertToken(LRUResourceCache & cache_) : cache(cache_) { }
std::mutex mutex;
bool cleaned_up = false; /// Protected by the token mutex
MappedPtr value; /// Protected by the token mutex
LRUResourceCache & cache;
size_t refcount = 0; /// Protected by the cache mutex
};
using InsertTokenById = std::unordered_map<Key, std::shared_ptr<InsertToken>, HashFunction>;
/// This class is responsible for removing used insert tokens from the insert_tokens map.
/// Among several concurrent threads, the first successful one is responsible for the removal;
/// if they all fail, the last one is responsible.
struct InsertTokenHolder
{
const Key * key = nullptr;
std::shared_ptr<InsertToken> token;
bool cleaned_up = false;
InsertTokenHolder() = default;
void
acquire(const Key * key_, const std::shared_ptr<InsertToken> & token_, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
{
key = key_;
token = token_;
++token->refcount;
}
void cleanup([[maybe_unused]] std::lock_guard<std::mutex> & token_lock, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
{
token->cache.insert_tokens.erase(*key);
token->cleaned_up = true;
cleaned_up = true;
}
~InsertTokenHolder()
{
if (!token)
return;
if (cleaned_up)
return;
std::lock_guard token_lock(token->mutex);
if (token->cleaned_up)
return;
std::lock_guard cache_lock(token->cache.mutex);
--token->refcount;
if (token->refcount == 0)
cleanup(token_lock, cache_lock);
}
};
friend struct InsertTokenHolder;
InsertTokenById insert_tokens;
WeightFunction weight_function;
std::atomic<size_t> hits{0};
std::atomic<size_t> misses{0};
std::atomic<size_t> evict_count{0};
// - load_func: called to generate a new value when the key does not exist in the cache.
// - returns nullptr when there is no more space for the new value or the old value is still in use.
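// Illustrative sketch (an assumption, not from the original code) of how a caller of getOrSet(),
// which wraps this method, might handle the nullptr case: fall back to an uncached instance.
//
//     auto holder = cache.getOrSet(key, load_func);
//     if (!holder)
//     {
//         /// The cached entry is still in use or the value does not fit; load without caching.
//         auto uncached_value = load_func();
//     }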
template <typename LoadFunc>
MappedPtr getImpl(const Key & key, LoadFunc && load_func)
{
InsertTokenHolder token_holder;
{
std::lock_guard lock(mutex);
auto it = cells.find(key);
if (it != cells.end())
{
if (!it->second.expired)
{
hits++;
it->second.reference_count += 1;
queue.splice(queue.end(), queue, it->second.queue_iterator);
return it->second.value;
}
else if (it->second.reference_count > 0)
return nullptr;
else
{
// should not reach here
LOG_ERROR(&Poco::Logger::get("LRUResourceCache"), "Element is in an invalid state.");
abort();
}
}
misses++;
auto & token = insert_tokens[key];
if (!token)
token = std::make_shared<InsertToken>(*this);
token_holder.acquire(&key, token, lock);
}
auto * token = token_holder.token.get();
std::lock_guard token_lock(token->mutex);
token_holder.cleaned_up = token->cleaned_up;
if (!token->value)
token->value = load_func();
std::lock_guard lock(mutex);
auto token_it = insert_tokens.find(key);
Cell * cell_ptr = nullptr;
if (token_it != insert_tokens.end() && token_it->second.get() == token)
{
cell_ptr = set(key, token->value);
}
else
{
auto cell_it = cells.find(key);
if (cell_it != cells.end() && !cell_it->second.expired)
{
cell_ptr = &cell_it->second;
}
}
if (!token->cleaned_up)
token_holder.cleanup(token_lock, lock);
if (cell_ptr)
{
queue.splice(queue.end(), queue, cell_ptr->queue_iterator);
cell_ptr->reference_count++;
return cell_ptr->value;
}
return nullptr;
}
MappedPtr getImpl(const Key & key)
{
std::lock_guard lock(mutex);
auto it = cells.find(key);
if (it == cells.end() || it->second.expired)
{
misses++;
return nullptr;
}
hits++;
it->second.reference_count += 1;
queue.splice(queue.end(), queue, it->second.queue_iterator);
return it->second.value;
}
// Marks that a reference to the key has been released.
void release(const Key & key)
{
std::lock_guard lock(mutex);
auto it = cells.find(key);
if (it == cells.end() || it->second.reference_count == 0)
{
LOG_ERROR(&Poco::Logger::get("LRUResourceCache"), "Attempt to release an invalid element.");
abort();
}
auto & cell = it->second;
cell.reference_count -= 1;
if (cell.expired && cell.reference_count == 0)
{
queue.erase(cell.queue_iterator);
current_weight -= cell.weight;
cells.erase(it);
}
}
InsertToken * acquireInsertToken(const Key & key)
{
auto & token = insert_tokens[key];
if (!token)
token = std::make_shared<InsertToken>(*this);
token->refcount += 1;
return token.get();
}
void releaseInsertToken(const Key & key)
{
auto it = insert_tokens.find(key);
if (it != insert_tokens.end())
{
it->second->refcount -= 1;
if (it->second->refcount == 0)
insert_tokens.erase(it);
}
}
// The key must not already be in the cache.
Cell * set(const Key & insert_key, MappedPtr value)
{
auto weight = value ? weight_function(*value) : 0;
auto queue_size = cells.size() + 1;
size_t loss_weight = 0;
auto is_overflow = [&] {
return current_weight + weight - loss_weight > max_weight || (max_element_size != 0 && queue_size > max_element_size);
};
auto key_it = queue.begin();
std::unordered_set<Key, HashFunction> to_release_keys;
while (is_overflow() && queue_size > 1 && key_it != queue.end())
{
const Key & key = *key_it;
auto cell_it = cells.find(key);
if (cell_it == cells.end())
{
LOG_ERROR(&Poco::Logger::get("LRUResourceCache"), "LRUResourceCache became inconsistent. There must be a bug in it.");
abort();
}
auto & cell = cell_it->second;
if (cell.reference_count == 0)
{
loss_weight += cell.weight;
queue_size -= 1;
to_release_keys.insert(key);
}
key_it++;
}
if (is_overflow())
return nullptr;
if (loss_weight > current_weight + weight)
{
LOG_ERROR(&Poco::Logger::get("LRUResourceCache"), "LRUResourceCache became inconsistent. There must be a bug in it.");
abort();
}
for (auto & key : to_release_keys)
{
auto & cell = cells[key];
queue.erase(cell.queue_iterator);
cells.erase(key);
evict_count++;
}
current_weight = current_weight + weight - loss_weight;
auto & new_cell = cells[insert_key];
new_cell.value = value;
new_cell.weight = weight;
new_cell.queue_iterator = queue.insert(queue.end(), insert_key);
return &new_cell;
}
};
}

@@ -0,0 +1,97 @@
#include <iomanip>
#include <iostream>
#include <gtest/gtest.h>
#include <Common/LRUCache.h>
TEST(LRUCache, set)
{
using SimpleLRUCache = DB::LRUCache<int, int>;
auto lru_cache = SimpleLRUCache(10, 10);
lru_cache.set(1, std::make_shared<int>(2));
lru_cache.set(2, std::make_shared<int>(3));
auto w = lru_cache.weight();
auto n = lru_cache.count();
ASSERT_EQ(w, 2);
ASSERT_EQ(n, 2);
}
TEST(LRUCache, update)
{
using SimpleLRUCache = DB::LRUCache<int, int>;
auto lru_cache = SimpleLRUCache(10, 10);
lru_cache.set(1, std::make_shared<int>(2));
lru_cache.set(1, std::make_shared<int>(3));
auto val = lru_cache.get(1);
ASSERT_TRUE(val != nullptr);
ASSERT_TRUE(*val == 3);
}
TEST(LRUCache, get)
{
using SimpleLRUCache = DB::LRUCache<int, int>;
auto lru_cache = SimpleLRUCache(10, 10);
lru_cache.set(1, std::make_shared<int>(2));
lru_cache.set(2, std::make_shared<int>(3));
SimpleLRUCache::MappedPtr value = lru_cache.get(1);
ASSERT_TRUE(value != nullptr);
ASSERT_EQ(*value, 2);
value = lru_cache.get(2);
ASSERT_TRUE(value != nullptr);
ASSERT_EQ(*value, 3);
}
struct ValueWeight
{
size_t operator()(const size_t & x) const { return x; }
};
TEST(LRUCache, evictOnSize)
{
using SimpleLRUCache = DB::LRUCache<int, size_t>;
auto lru_cache = SimpleLRUCache(20, 3);
lru_cache.set(1, std::make_shared<size_t>(2));
lru_cache.set(2, std::make_shared<size_t>(3));
lru_cache.set(3, std::make_shared<size_t>(4));
lru_cache.set(4, std::make_shared<size_t>(5));
auto n = lru_cache.count();
ASSERT_EQ(n, 3);
auto value = lru_cache.get(1);
ASSERT_TRUE(value == nullptr);
}
TEST(LRUCache, evictOnWeight)
{
using SimpleLRUCache = DB::LRUCache<int, size_t, std::hash<int>, ValueWeight>;
auto lru_cache = SimpleLRUCache(10, 10);
lru_cache.set(1, std::make_shared<size_t>(2));
lru_cache.set(2, std::make_shared<size_t>(3));
lru_cache.set(3, std::make_shared<size_t>(4));
lru_cache.set(4, std::make_shared<size_t>(5));
auto n = lru_cache.count();
ASSERT_EQ(n, 2);
auto w = lru_cache.weight();
ASSERT_EQ(w, 9);
auto value = lru_cache.get(1);
ASSERT_TRUE(value == nullptr);
value = lru_cache.get(2);
ASSERT_TRUE(value == nullptr);
}
TEST(LRUCache, getOrSet)
{
using SimpleLRUCache = DB::LRUCache<int, size_t, std::hash<int>, ValueWeight>;
auto lru_cache = SimpleLRUCache(10, 10);
size_t x = 10;
auto load_func = [&] { return std::make_shared<size_t>(x); };
auto [value, loaded] = lru_cache.getOrSet(1, load_func);
ASSERT_TRUE(value != nullptr);
ASSERT_TRUE(*value == 10);
}

@@ -0,0 +1,270 @@
#include <iomanip>
#include <iostream>
#include <gtest/gtest.h>
#include <Common/LRUResourceCache.h>
TEST(LRUResourceCache, get)
{
using MyCache = DB::LRUResourceCache<int, int>;
auto mcache = MyCache(10, 10);
int x = 10;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder1 = mcache.getOrSet(1, load_int);
x = 11;
auto holder2 = mcache.getOrSet(2, load_int);
ASSERT_TRUE(holder2 != nullptr);
ASSERT_TRUE(holder2->value() == 11);
auto holder3 = mcache.get(1);
ASSERT_TRUE(holder3 != nullptr);
ASSERT_TRUE(holder3->value() == 10);
}
TEST(LRUResourceCache, remove)
{
using MyCache = DB::LRUResourceCache<int, int>;
auto mcache = MyCache(10, 10);
int x = 10;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder0 = mcache.getOrSet(1, load_int);
auto holder1 = mcache.getOrSet(1, load_int);
mcache.tryRemove(1);
holder0 = mcache.get(1);
ASSERT_TRUE(holder0 == nullptr);
auto n = mcache.size();
ASSERT_TRUE(n == 1);
holder0.reset();
holder1.reset();
n = mcache.size();
ASSERT_TRUE(n == 0);
}
struct MyWeight
{
size_t operator()(const int & x) const { return static_cast<size_t>(x); }
};
TEST(LRUResourceCache, evictOnWeight)
{
using MyCache = DB::LRUResourceCache<int, int, MyWeight>;
auto mcache = MyCache(5, 10);
int x = 2;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder1 = mcache.getOrSet(1, load_int);
holder1.reset();
auto holder2 = mcache.getOrSet(2, load_int);
holder2.reset();
x = 3;
auto holder3 = mcache.getOrSet(3, load_int);
ASSERT_TRUE(holder3 != nullptr);
auto w = mcache.weight();
ASSERT_EQ(w, 5);
auto n = mcache.size();
ASSERT_EQ(n, 2);
holder1 = mcache.get(1);
ASSERT_TRUE(holder1 == nullptr);
holder2 = mcache.get(2);
ASSERT_TRUE(holder2 != nullptr);
holder3 = mcache.get(3);
ASSERT_TRUE(holder3 != nullptr);
}
TEST(LRUResourceCache, evictOnWeightV2)
{
using MyCache = DB::LRUResourceCache<int, int, MyWeight>;
auto mcache = MyCache(5, 10);
int x = 2;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder1 = mcache.getOrSet(1, load_int);
holder1.reset();
auto holder2 = mcache.getOrSet(2, load_int);
holder2.reset();
holder1 = mcache.get(1);
holder1.reset();
x = 3;
auto holder3 = mcache.getOrSet(3, load_int);
ASSERT_TRUE(holder3 != nullptr);
auto w = mcache.weight();
ASSERT_EQ(w, 5);
auto n = mcache.size();
ASSERT_EQ(n, 2);
holder1 = mcache.get(1);
ASSERT_TRUE(holder1 != nullptr);
holder2 = mcache.get(2);
ASSERT_TRUE(holder2 == nullptr);
holder3 = mcache.get(3);
ASSERT_TRUE(holder3 != nullptr);
}
TEST(LRUResourceCache, evictOnWeightV3)
{
using MyCache = DB::LRUResourceCache<int, int, MyWeight>;
auto mcache = MyCache(5, 10);
int x = 2;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder1 = mcache.getOrSet(1, load_int);
holder1.reset();
auto holder2 = mcache.getOrSet(2, load_int);
holder2.reset();
holder1 = mcache.getOrSet(1, load_int);
holder1.reset();
x = 3;
auto holder3 = mcache.getOrSet(3, load_int);
ASSERT_TRUE(holder3 != nullptr);
auto w = mcache.weight();
ASSERT_EQ(w, 5);
auto n = mcache.size();
ASSERT_EQ(n, 2);
holder1 = mcache.get(1);
ASSERT_TRUE(holder1 != nullptr);
holder2 = mcache.get(2);
ASSERT_TRUE(holder2 == nullptr);
holder3 = mcache.get(3);
ASSERT_TRUE(holder3 != nullptr);
}
TEST(LRUResourceCache, evictOnSize)
{
using MyCache = DB::LRUResourceCache<int, int>;
auto mcache = MyCache(5, 2);
int x = 2;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder1 = mcache.getOrSet(1, load_int);
holder1.reset();
auto holder2 = mcache.getOrSet(2, load_int);
holder2.reset();
x = 3;
auto holder3 = mcache.getOrSet(3, load_int);
ASSERT_TRUE(holder3 != nullptr);
auto n = mcache.size();
ASSERT_EQ(n, 2);
auto w = mcache.weight();
ASSERT_EQ(w, 2);
holder1 = mcache.get(1);
ASSERT_TRUE(holder1 == nullptr);
holder2 = mcache.get(2);
ASSERT_TRUE(holder2 != nullptr);
holder3 = mcache.get(3);
ASSERT_TRUE(holder3 != nullptr);
}
TEST(LRUResourceCache, notEvictUsedElement)
{
using MyCache = DB::LRUResourceCache<int, int, MyWeight>;
auto mcache = MyCache(7, 10);
int x = 2;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder1 = mcache.getOrSet(1, load_int);
auto holder2 = mcache.getOrSet(2, load_int);
holder2.reset();
auto holder3 = mcache.getOrSet(3, load_int);
holder3.reset();
x = 3;
auto holder4 = mcache.getOrSet(4, load_int);
ASSERT_TRUE(holder4 != nullptr);
auto n = mcache.size();
ASSERT_EQ(n, 3);
auto w = mcache.weight();
ASSERT_EQ(w, 7);
holder1 = mcache.get(1);
ASSERT_TRUE(holder1 != nullptr);
holder2 = mcache.get(2);
ASSERT_TRUE(holder2 == nullptr);
holder3 = mcache.get(3);
ASSERT_TRUE(holder3 != nullptr);
holder4 = mcache.get(4);
ASSERT_TRUE(holder4 != nullptr);
}
TEST(LRUResourceCache, getFail)
{
using MyCache = DB::LRUResourceCache<int, int, MyWeight>;
auto mcache = MyCache(5, 10);
int x = 2;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder1 = mcache.getOrSet(1, load_int);
auto holder2 = mcache.getOrSet(2, load_int);
auto holder3 = mcache.getOrSet(3, load_int);
ASSERT_TRUE(holder3 == nullptr);
auto n = mcache.size();
ASSERT_EQ(n, 2);
auto w = mcache.weight();
ASSERT_EQ(w, 4);
holder1 = mcache.get(1);
ASSERT_TRUE(holder1 != nullptr);
holder2 = mcache.get(2);
ASSERT_TRUE(holder2 != nullptr);
holder3 = mcache.get(3);
ASSERT_TRUE(holder3 == nullptr);
}
TEST(LRUResourceCache, dupGet)
{
using MyCache = DB::LRUResourceCache<int, int, MyWeight>;
auto mcache = MyCache(20, 10);
int x = 2;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder1 = mcache.getOrSet(1, load_int);
holder1.reset();
x = 11;
holder1 = mcache.getOrSet(1, load_int);
ASSERT_TRUE(holder1 != nullptr);
auto n = mcache.size();
ASSERT_EQ(n, 1);
auto w = mcache.weight();
ASSERT_EQ(w, 2);
holder1 = mcache.get(1);
ASSERT_TRUE(holder1 != nullptr);
ASSERT_TRUE(holder1->value() == 2);
}
TEST(LRUResourceCache, reGet)
{
using MyCache = DB::LRUResourceCache<int, int, MyWeight>;
auto mcache = MyCache(20, 10);
int x = 2;
auto load_int = [&] { return std::make_shared<int>(x); };
auto holder1 = mcache.getOrSet(1, load_int);
mcache.tryRemove(1);
x = 11;
holder1.reset();
holder1 = mcache.getOrSet(1, load_int);
ASSERT_TRUE(holder1 != nullptr);
auto n = mcache.size();
ASSERT_EQ(n, 1);
auto w = mcache.weight();
ASSERT_EQ(w, 11);
holder1 = mcache.get(1);
ASSERT_TRUE(holder1 != nullptr);
ASSERT_TRUE(holder1->value() == 11);
}

@@ -361,7 +361,6 @@ void Aggregator::compileAggregateFunctionsIfNeeded()
auto compiled_aggregate_functions = compileAggregateFunctions(getJITInstance(), functions_to_compile, functions_description);
return std::make_shared<CompiledAggregateFunctionsHolder>(std::move(compiled_aggregate_functions));
});
compiled_aggregate_functions_holder = std::static_pointer_cast<CompiledAggregateFunctionsHolder>(compiled_function_cache_entry);
}
else