Check for array size overflow in topK #14452

This commit is contained in:
Alexey Milovidov 2020-09-04 04:05:57 +03:00
parent 4f9df21d3e
commit 1cee6d5a31
4 changed files with 8 additions and 6 deletions

View File

@ -85,12 +85,12 @@ AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const
load_factor = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]); load_factor = applyVisitor(FieldVisitorConvertToNumber<UInt64>(), params[1]);
if (load_factor < 1) if (load_factor < 1)
throw Exception("Too small parameter for aggregate function " + name + ". Minimum: 1", throw Exception("Too small parameter 'load_factor' for aggregate function " + name + ". Minimum: 1",
ErrorCodes::ARGUMENT_OUT_OF_BOUND); ErrorCodes::ARGUMENT_OUT_OF_BOUND);
} }
if (k > TOP_K_MAX_SIZE) if (k > TOP_K_MAX_SIZE || load_factor > TOP_K_MAX_SIZE || k * load_factor > TOP_K_MAX_SIZE)
throw Exception("Too large parameter for aggregate function " + name + ". Maximum: " + toString(TOP_K_MAX_SIZE), throw Exception("Too large parameter(s) for aggregate function " + name + ". Maximum: " + toString(TOP_K_MAX_SIZE),
ErrorCodes::ARGUMENT_OUT_OF_BOUND); ErrorCodes::ARGUMENT_OUT_OF_BOUND);
if (k == 0) if (k == 0)

View File

@ -147,16 +147,17 @@ public:
{ {
// Increase weight of a key that already exists // Increase weight of a key that already exists
auto hash = counter_map.hash(key); auto hash = counter_map.hash(key);
auto counter = findCounter(key, hash);
if (counter) if (auto counter = findCounter(key, hash); counter)
{ {
counter->count += increment; counter->count += increment;
counter->error += error; counter->error += error;
percolate(counter); percolate(counter);
return; return;
} }
// Key doesn't exist, but can fit in the top K // Key doesn't exist, but can fit in the top K
else if (unlikely(size() < capacity())) if (unlikely(size() < capacity()))
{ {
auto c = new Counter(arena.emplace(key), increment, error, hash); auto c = new Counter(arena.emplace(key), increment, error, hash);
push(c); push(c);

View File

@ -0,0 +1 @@
SELECT length(topKWeighted(2, -9223372036854775808)(number, 1025)) FROM system.numbers; -- { serverError 69 }