Fix undefined-behavior in ReservoirSamplerDeterministic.h

This commit is contained in:
Vitaly Baranov 2021-02-18 16:32:01 +03:00
parent c704a8cc45
commit 556dc81ab9

View File

@ -56,7 +56,7 @@ class ReservoirSamplerDeterministic
{
bool good(const UInt32 hash)
{
return hash == ((hash >> skip_degree) << skip_degree);
return !(hash & skip_mask);
}
public:
@ -135,11 +135,8 @@ public:
throw Poco::Exception("Cannot merge ReservoirSamplerDeterministic's with different max sample size");
sorted = false;
if (b.skip_degree > skip_degree)
{
skip_degree = b.skip_degree;
thinOut();
}
if (skip_degree < b.skip_degree)
setSkipDegree(b.skip_degree);
for (const auto & sample : b.samples)
if (good(sample.second))
@ -184,22 +181,39 @@ private:
size_t total_values = 0; /// How many values were inserted (regardless if they remain in sample or not).
bool sorted = false;
Array samples;
UInt8 skip_degree = 0; /// The number N determining that we save only one per 2^N elements in average.
/// The number N determining that we store only one per 2^N elements in average.
UInt8 skip_degree = 0;
/// skip_mask is calculated as (2 ^ skip_degree - 1). We store an element only if (hash & skip_mask) == 0.
/// For example, if skip_degree==0 then skip_mask==0 means we store each element;
/// if skip_degree==1 then skip_mask==0b0001 means we store one per 2 elements in average;
/// if skip_degree==4 then skip_mask==0b1111 means we store one per 16 elements in average.
UInt32 skip_mask = 0;
void insertImpl(const T & v, const UInt32 hash)
{
/// Make a room for plus one element.
while (samples.size() >= max_sample_size)
{
++skip_degree;
if (skip_degree > detail::MAX_SKIP_DEGREE)
throw DB::Exception{"skip_degree exceeds maximum value", DB::ErrorCodes::MEMORY_LIMIT_EXCEEDED};
thinOut();
}
setSkipDegree(skip_degree + 1);
samples.emplace_back(v, hash);
}
void setSkipDegree(UInt8 skip_degree_)
{
if (skip_degree_ == skip_degree)
return;
if (skip_degree_ > detail::MAX_SKIP_DEGREE)
throw DB::Exception{"skip_degree exceeds maximum value", DB::ErrorCodes::MEMORY_LIMIT_EXCEEDED};
skip_degree = skip_degree_;
if (skip_degree == detail::MAX_SKIP_DEGREE)
skip_mask = static_cast<UInt32>(-1);
else
skip_mask = (1 << skip_degree) - 1;
thinOut();
}
void thinOut()
{
samples.resize(std::distance(samples.begin(),