Add tests for SLRU cache

This commit is contained in:
alexX512 2022-03-10 20:03:05 +03:00
parent d7bdf9e601
commit a0d2da7261
2 changed files with 140 additions and 8 deletions

View File

@ -1,3 +1,5 @@
#pragma once
#include <Common/LRUCache.h>
namespace DB
@ -21,11 +23,12 @@ public:
using Mapped = TMapped;
using MappedPtr = std::shared_ptr<Mapped>;
using Base = LRUCache<Key, Mapped, HashFunction, WeightFunction>;
using Base::mutex;
SLRUCache(size_t max_protected_size_, size_t max_size_)
: max_protected_size(std::max(static_cast<size_t>(1), max_protected_size_))
: Base(0)
, max_protected_size(std::max(static_cast<size_t>(1), max_protected_size_))
, max_size(std::max(max_protected_size + 1, max_size_))
, Base(max_size)
{}
void remove(const Key & key)
@ -53,6 +56,11 @@ public:
return cells.size();
}
size_t maxSize() const
{
return max_size;
}
size_t maxProtectedSize() const
{
return max_protected_size;
@ -94,10 +102,12 @@ private:
SLRUQueue probationary_queue;
SLRUQueue protected_queue;
size_t current_size = 0;
size_t current_protected_size = 0;
const size_t max_size;
size_t current_size = 0;
const size_t max_protected_size;
const size_t max_size;
WeightFunction weight_function;
MappedPtr getImpl(const Key & key, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock) override
{
@ -169,12 +179,12 @@ private:
removeOverflow(probationary_queue, max_size, current_size);
}
void removeOverflow(SLRUQueue & queue, const size_t max_size, size_t & current_size)
void removeOverflow(SLRUQueue & queue, const size_t max_weight_size, size_t & current_weight_size)
{
size_t current_weight_lost = 0;
size_t queue_size = queue.size();
while (current_size > max_size && queue_size > 1)
while (current_weight_size > max_weight_size && queue_size > 1)
{
const Key & key = queue.front();
@ -185,9 +195,9 @@ private:
abort();
}
const auto & cell = it->second;
auto & cell = it->second;
current_size -= cell.size;
current_weight_size -= cell.size;
if (cell.is_protected)
{
@ -213,6 +223,9 @@ private:
}
}
/// Override this method if you want to track how much weight was lost in removeOverflow method.
virtual void onRemoveOverflowWeightLoss(size_t /*weight_loss*/) override {}
};

View File

@ -0,0 +1,119 @@
#include <iomanip>
#include <iostream>
#include <gtest/gtest.h>
#include <Common/SLRUCache.h>
TEST(SLRUCache, set)
{
using SimpleSLRUCache = DB::SLRUCache<int, int>;
auto slru_cache = SimpleSLRUCache(/*max_protected_size=*/5, /*max_total_size=*/10);
slru_cache.set(1, std::make_shared<int>(2));
slru_cache.set(2, std::make_shared<int>(3));
auto w = slru_cache.weight();
auto n = slru_cache.count();
ASSERT_EQ(w, 2);
ASSERT_EQ(n, 2);
}
TEST(SLRUCache, update)
{
using SimpleSLRUCache = DB::SLRUCache<int, int>;
auto slru_cache = SimpleSLRUCache(/*max_protected_size=*/5, /*max_total_size=*/10);
slru_cache.set(1, std::make_shared<int>(2));
slru_cache.set(1, std::make_shared<int>(3));
auto val = slru_cache.get(1);
ASSERT_TRUE(val != nullptr);
ASSERT_TRUE(*val == 3);
}
TEST(SLRUCache, get)
{
using SimpleSLRUCache = DB::SLRUCache<int, int>;
auto slru_cache = SimpleSLRUCache(/*max_protected_size=*/5, /*max_total_size=*/10);
slru_cache.set(1, std::make_shared<int>(2));
slru_cache.set(2, std::make_shared<int>(3));
SimpleSLRUCache::MappedPtr value = slru_cache.get(1);
ASSERT_TRUE(value != nullptr);
ASSERT_EQ(*value, 2);
value = slru_cache.get(2);
ASSERT_TRUE(value != nullptr);
ASSERT_EQ(*value, 3);
}
struct ValueWeight
{
size_t operator()(const size_t & x) const { return x; }
};
TEST(SLRUCache, evictOnWeight)
{
using SimpleSLRUCache = DB::SLRUCache<int, size_t, std::hash<int>, ValueWeight>;
auto slru_cache = SimpleSLRUCache(/*max_protected_size=*/5, /*max_total_size=*/10);
slru_cache.set(1, std::make_shared<size_t>(2));
slru_cache.set(2, std::make_shared<size_t>(3));
slru_cache.set(3, std::make_shared<size_t>(4));
slru_cache.set(3, std::make_shared<size_t>(5));
auto n = slru_cache.count();
ASSERT_EQ(n, 2);
auto w = slru_cache.weight();
ASSERT_EQ(w, 9);
auto value = slru_cache.get(1);
ASSERT_TRUE(value == nullptr);
value = slru_cache.get(2);
ASSERT_TRUE(value == nullptr);
}
TEST(SLRUCache, evictFromProtectedPart)
{
using SimpleSLRUCache = DB::SLRUCache<int, size_t, std::hash<int>, ValueWeight>;
auto slru_cache = SimpleSLRUCache(/*max_protected_size=*/5, /*max_total_size=*/10);
slru_cache.set(1, std::make_shared<size_t>(2));
slru_cache.set(1, std::make_shared<size_t>(2));
slru_cache.set(2, std::make_shared<size_t>(5));
slru_cache.set(2, std::make_shared<size_t>(5));
slru_cache.set(3, std::make_shared<size_t>(5));
auto value = slru_cache.get(1);
ASSERT_TRUE(value == nullptr);
}
TEST(SLRUCache, evictStreamProtected)
{
using SimpleSLRUCache = DB::SLRUCache<int, size_t, std::hash<int>, ValueWeight>;
auto slru_cache = SimpleSLRUCache(/*max_protected_size=*/5, /*max_total_size=*/10);
slru_cache.set(1, std::make_shared<size_t>(2));
slru_cache.set(1, std::make_shared<size_t>(2));
slru_cache.set(2, std::make_shared<size_t>(3));
slru_cache.set(2, std::make_shared<size_t>(3));
for (int key = 3; key < 10; ++key) {
slru_cache.set(key, std::make_shared<size_t>(1 + key % 5));
}
auto value = slru_cache.get(1);
ASSERT_TRUE(value != nullptr);
ASSERT_EQ(*value, 2);
value = slru_cache.get(2);
ASSERT_TRUE(value != nullptr);
ASSERT_EQ(*value, 3);
}
TEST(SLRUCache, getOrSet)
{
using SimpleSLRUCache = DB::SLRUCache<int, size_t, std::hash<int>, ValueWeight>;
auto slru_cache = SimpleSLRUCache(/*max_protected_size=*/5, /*max_total_size=*/10);
size_t x = 5;
auto load_func = [&] { return std::make_shared<size_t>(x); };
auto [value, loaded] = slru_cache.getOrSet(1, load_func);
ASSERT_TRUE(value != nullptr);
ASSERT_TRUE(*value == 5);
}