diff --git a/src/AggregateFunctions/ReservoirSamplerDeterministic.h b/src/AggregateFunctions/ReservoirSamplerDeterministic.h index 3097070c651..f0f926ce31e 100644 --- a/src/AggregateFunctions/ReservoirSamplerDeterministic.h +++ b/src/AggregateFunctions/ReservoirSamplerDeterministic.h @@ -39,8 +39,8 @@ namespace ErrorCodes namespace detail { -const size_t DEFAULT_SAMPLE_COUNT = 8192; -const auto MAX_SKIP_DEGREE = sizeof(UInt32) * 8; + const size_t DEFAULT_MAX_SAMPLE_SIZE = 8192; + const auto MAX_SKIP_DEGREE = sizeof(UInt32) * 8; } /// What if there is not a single value - throw an exception, or return 0 or NaN in the case of double? @@ -50,6 +50,7 @@ enum class ReservoirSamplerDeterministicOnEmpty RETURN_NAN_OR_ZERO, }; + template class ReservoirSamplerDeterministic @@ -60,8 +61,8 @@ class ReservoirSamplerDeterministic } public: - ReservoirSamplerDeterministic(const size_t sample_count_ = DEFAULT_SAMPLE_COUNT) - : sample_count{sample_count_} + ReservoirSamplerDeterministic(const size_t max_sample_size_ = detail::DEFAULT_MAX_SAMPLE_SIZE) + : max_sample_size{max_sample_size_} { } @@ -131,8 +132,8 @@ public: void merge(const ReservoirSamplerDeterministic & b) { - if (sample_count != b.sample_count) - throw Poco::Exception("Cannot merge ReservoirSamplerDeterministic's with different sample_count"); + if (max_sample_size != b.max_sample_size) + throw Poco::Exception("Cannot merge ReservoirSamplerDeterministic's with different max sample count"); sorted = false; if (b.skip_degree > skip_degree) @@ -150,11 +151,16 @@ public: void read(DB::ReadBuffer & buf) { - DB::readIntBinary(sample_count, buf); + size_t size = 0; + DB::readIntBinary(size, buf); DB::readIntBinary(total_values, buf); - samples.resize(std::min(total_values, sample_count)); - for (size_t i = 0; i < samples.size(); ++i) + /// Compatibility with old versions. + if (size > total_values) + size = total_values; + + samples.resize(size); + for (size_t i = 0; i < size; ++i) DB::readPODBinary(samples[i], buf); sorted = false; @@ -162,10 +168,11 @@ public: void write(DB::WriteBuffer & buf) const { - DB::writeIntBinary(sample_count, buf); + size_t size = samples.size(); + DB::writeIntBinary(size, buf); DB::writeIntBinary(total_values, buf); - for (size_t i = 0; i < std::min(sample_count, total_values); ++i) + for (size_t i = 0; i < size; ++i) DB::writePODBinary(samples[i], buf); } @@ -174,18 +181,19 @@ private: using Element = std::pair; using Array = DB::PODArray; - size_t sample_count; - size_t total_values{}; - bool sorted{}; + const size_t max_sample_size; /// Maximum amount of stored values. + size_t total_values = 0; /// How many values were inserted (regardless if they remain in sample or not). + bool sorted = false; Array samples; - UInt8 skip_degree{}; + UInt8 skip_degree = 0; /// The number N determining that we save only one per 2^N elements in average. void insertImpl(const T & v, const UInt32 hash) { - /// @todo why + 1? I don't quite recall - while (samples.size() + 1 >= sample_count) + /// Make a room for plus one element. + while (samples.size() >= max_sample_size) { - if (++skip_degree > detail::MAX_SKIP_DEGREE) + ++skip_degree; + if (skip_degree > detail::MAX_SKIP_DEGREE) throw DB::Exception{"skip_degree exceeds maximum value", DB::ErrorCodes::MEMORY_LIMIT_EXCEEDED}; thinOut(); } @@ -195,35 +203,17 @@ private: void thinOut() { - auto size = samples.size(); - for (size_t i = 0; i < size;) - { - if (!good(samples[i].second)) - { - /// swap current element with the last one - std::swap(samples[size - 1], samples[i]); - --size; - } - else - ++i; - } - - if (size != samples.size()) - { - samples.resize(size); - sorted = false; - } + samples.resize(std::distance(samples.begin(), + std::remove_if(samples.begin(), samples.end(), [this](const auto & elem){ return !good(elem.second); }))); + sorted = false; } void sortIfNeeded() { if (sorted) return; + std::sort(samples.begin(), samples.end(), [](const auto & lhs, const auto & rhs) { return lhs.first < rhs.first; }); sorted = true; - std::sort(samples.begin(), samples.end(), [] (const std::pair & lhs, const std::pair & rhs) - { - return lhs.first < rhs.first; - }); } template