AggregateFunctions: implemented topK(n)

This implements a new aggregate function for approximate
computation of the most frequent values, using
Filtered Space-Saving with a merge step adapted
from the Parallel Space Saving paper.

It is useful for cases where GROUP BY x
is impractical due to the high cardinality of x,
such as top IP addresses or top search queries.
Authored by Marek Vavruša on 2017-05-02 14:08:37 -07:00; committed by alexey-milovidov
parent d2d7aaac69
commit 5f1e65b252
11 changed files with 723 additions and 0 deletions
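As a rough illustration of the plain Space-Saving update rule that Filtered Space-Saving builds on, here is a minimal, self-contained C++ sketch. It is not the committed implementation: the names SimpleSpaceSaving, counters and index are invented for this illustration, and the committed SpaceSaving class (in the Common/SpaceSaving.h diff below) additionally keeps its counter list sorted by percolation and uses the alpha map to skip keys whose estimated count cannot exceed the current minimum.

#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Minimal Space-Saving sketch: keep at most `capacity` counters; when a new
// key arrives and the structure is full, overwrite the counter with the
// smallest count and remember that count as the new key's error bound.
struct SimpleSpaceSaving
{
    struct Counter { std::string key; uint64_t count; uint64_t error; };

    explicit SimpleSpaceSaving(size_t capacity_) : capacity(capacity_) {}

    void insert(const std::string & key)
    {
        auto it = index.find(key);
        if (it != index.end())
        {
            ++counters[it->second].count;       // monitored key: just increment
            return;
        }
        if (counters.size() < capacity)
        {
            index[key] = counters.size();
            counters.push_back({key, 1, 0});    // free slot: start a new counter
            return;
        }
        // Full: replace the minimum counter, inheriting its count as the error.
        auto min_it = std::min_element(counters.begin(), counters.end(),
            [](const Counter & a, const Counter & b) { return a.count < b.count; });
        index.erase(min_it->key);
        index[key] = static_cast<size_t>(min_it - counters.begin());
        min_it->error = min_it->count;
        min_it->key = key;
        ++min_it->count;
    }

    // Return up to k counters with the largest counts, in descending order.
    std::vector<Counter> topK(size_t k) const
    {
        std::vector<Counter> res = counters;
        std::sort(res.begin(), res.end(),
            [](const Counter & a, const Counter & b) { return a.count > b.count; });
        if (res.size() > k)
            res.resize(k);
        return res;
    }

    size_t capacity;
    std::vector<Counter> counters;
    std::unordered_map<std::string, size_t> index;
};

Each reported count overestimates the true frequency by at most the counter's error, which is the property the merge step in Common/SpaceSaving.h relies on.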


@@ -46,6 +46,7 @@ void registerAggregateFunctionsStatistics(AggregateFunctionFactory & factory);
void registerAggregateFunctionSum(AggregateFunctionFactory & factory);
void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory);
void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory & factory);
void registerAggregateFunctionTopK(AggregateFunctionFactory & factory);
void registerAggregateFunctionDebug(AggregateFunctionFactory & factory);
AggregateFunctionPtr createAggregateFunctionArray(AggregateFunctionPtr & nested);
@@ -76,6 +77,7 @@ AggregateFunctionFactory::AggregateFunctionFactory()
registerAggregateFunctionSum(*this);
registerAggregateFunctionsUniq(*this);
registerAggregateFunctionUniqUpTo(*this);
registerAggregateFunctionTopK(*this);
registerAggregateFunctionDebug(*this);
}


@@ -0,0 +1,70 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionTopK.h>
#include <AggregateFunctions/Helpers.h>
namespace DB
{
namespace
{
/// Substitute return type for Date and DateTime
class AggregateFunctionTopKDate : public AggregateFunctionTopK<DataTypeDate::FieldType>
{
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDate>()); }
};
class AggregateFunctionTopKDateTime : public AggregateFunctionTopK<DataTypeDateTime::FieldType>
{
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeArray>(std::make_shared<DataTypeDateTime>()); }
};
static IAggregateFunction * createWithExtraTypes(const IDataType & argument_type)
{
if (typeid_cast<const DataTypeDate *>(&argument_type)) return new AggregateFunctionTopKDate;
else if (typeid_cast<const DataTypeDateTime *>(&argument_type)) return new AggregateFunctionTopKDateTime;
else
{
/// Check that we can use plain version of AggregateFunctionTopKGeneric
if (typeid_cast<const DataTypeString*>(&argument_type) || typeid_cast<const DataTypeFixedString*>(&argument_type))
return new AggregateFunctionTopKGeneric<true>;
auto * array_type = typeid_cast<const DataTypeArray *>(&argument_type);
if (array_type)
{
auto nested_type = array_type->getNestedType();
if (nested_type->isNumeric() || typeid_cast<DataTypeFixedString *>(nested_type.get()))
return new AggregateFunctionTopKGeneric<true>;
}
return new AggregateFunctionTopKGeneric<false>;
}
}
AggregateFunctionPtr createAggregateFunctionTopK(const std::string & name, const DataTypes & argument_types)
{
if (argument_types.size() != 1)
throw Exception("Incorrect number of arguments for aggregate function " + name,
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
AggregateFunctionPtr res(createWithNumericType<AggregateFunctionTopK>(*argument_types[0]));
if (!res)
res = AggregateFunctionPtr(createWithExtraTypes(*argument_types[0]));
if (!res)
throw Exception("Illegal type " + argument_types[0]->getName() +
" of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
}
void registerAggregateFunctionTopK(AggregateFunctionFactory & factory)
{
factory.registerFunction("topK", createAggregateFunctionTopK);
}
}


@@ -0,0 +1,264 @@
#pragma once
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeString.h>
#include <Columns/ColumnArray.h>
#include <Common/SpaceSaving.h>
#include <Core/FieldVisitors.h>
#include <AggregateFunctions/AggregateFunctionGroupArray.h>
namespace DB
{
// Allow TOP_K_LOAD_FACTOR * K extra space before calculating the top K, to increase accuracy
#define TOP_K_LOAD_FACTOR 3
#define TOP_K_MAX_SIZE 0xFFFFFF
template <typename T>
struct AggregateFunctionTopKData
{
using Set = SpaceSaving<T, DefaultHash<T>>;
Set value;
};
template <typename T>
class AggregateFunctionTopK
: public IUnaryAggregateFunction<AggregateFunctionTopKData<T>, AggregateFunctionTopK<T>>
{
private:
using State = AggregateFunctionTopKData<T>;
size_t threshold = 10; // Default value if the parameter is not specified.
size_t reserved = TOP_K_LOAD_FACTOR * threshold;
public:
String getName() const override { return "topK"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeNumber<T>>());
}
void setArgument(const DataTypePtr & argument)
{
}
void setParameters(const Array & params) override
{
if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
std::size_t k = applyVisitor(FieldVisitorConvertToNumber<size_t>(), params[0]);
if (k > TOP_K_MAX_SIZE)
throw Exception("Too large parameter for aggregate function " + getName() + ". Maximum: " + toString(TOP_K_MAX_SIZE),
ErrorCodes::ARGUMENT_OUT_OF_BOUND);
threshold = k;
reserved = TOP_K_LOAD_FACTOR * k;
}
void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num, Arena *) const
{
auto & set = this->data(place).value;
if (set.capacity() != reserved) {
set.resize(reserved);
}
set.insert(static_cast<const ColumnVector<T> &>(column).getData()[row_num]);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
this->data(place).value.merge(this->data(rhs).value);
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).value.write(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
auto & set = this->data(place).value;
set.resize(reserved);
set.read(buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
const typename State::Set & set = this->data(place).value;
auto resultVec = set.topK(threshold);
size_t size = resultVec.size();
offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + size);
typename ColumnVector<T>::Container_t & data_to = static_cast<ColumnVector<T> &>(arr_to.getData()).getData();
size_t old_size = data_to.size();
data_to.resize(old_size + size);
size_t i = 0;
for (auto it = resultVec.begin(); it != resultVec.end(); ++it, ++i)
data_to[old_size + i] = it->key;
}
};
/// Generic implementation; it uses the serialized value representation as the object descriptor.
struct AggregateFunctionTopKGenericData
{
using Set = SpaceSaving<StringRef, StringRefHash>;
Set value;
};
/** The template parameter with value true should be used for columns that store their elements in memory contiguously.
 * For such columns topK() can be implemented more efficiently (especially for small numeric arrays).
 */
template <bool is_plain_column = false>
class AggregateFunctionTopKGeneric : public IUnaryAggregateFunction<AggregateFunctionTopKGenericData, AggregateFunctionTopKGeneric<is_plain_column>>
{
private:
using State = AggregateFunctionTopKGenericData;
DataTypePtr input_data_type;
size_t threshold = 10; // Default value if the parameter is not specified.
size_t reserved = TOP_K_LOAD_FACTOR * threshold;
static StringRef getSerialization(const IColumn & column, size_t row_num, Arena & arena);
static void deserializeAndInsert(StringRef str, IColumn & data_to);
public:
String getName() const override { return "topK"; }
void setArgument(const DataTypePtr & argument)
{
input_data_type = argument;
}
void setParameters(const Array & params) override
{
if (params.size() != 1)
throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
size_t k = applyVisitor(FieldVisitorConvertToNumber<size_t>(), params[0]);
if (k > TOP_K_MAX_SIZE)
throw Exception("Too large parameter for aggregate function " + getName() + ". Maximum: " + toString(TOP_K_MAX_SIZE),
ErrorCodes::ARGUMENT_OUT_OF_BOUND);
threshold = k;
reserved = TOP_K_LOAD_FACTOR * k;
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(input_data_type->clone());
}
bool allocatesMemoryInArena() const override
{
return true;
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).value.write(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{
auto & set = this->data(place).value;
set.resize(reserved);
size_t count = 0;
readVarUInt(count, buf);
for (size_t i = 0; i < count; ++i) {
auto key = readStringBinaryInto(*arena, buf);
UInt64 count, error;
readVarUInt(count, buf);
readVarUInt(error, buf);
set.insert(key, count, error);
}
}
void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num, Arena * arena) const
{
auto & set = this->data(place).value;
if (set.capacity() != reserved) {
set.resize(reserved);
}
StringRef str_serialized = getSerialization(column, row_num, *arena);
if (is_plain_column) {
auto ptr = arena->insert(str_serialized.data, str_serialized.size);
str_serialized = StringRef(ptr, str_serialized.size);
}
set.insert(str_serialized);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
this->data(place).value.merge(this->data(rhs).value);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets();
IColumn & data_to = arr_to.getData();
auto resultVec = this->data(place).value.topK(threshold);
offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + resultVec.size());
for (auto & elem : resultVec)
{
deserializeAndInsert(elem.key, data_to);
}
}
};
template <>
inline StringRef AggregateFunctionTopKGeneric<false>::getSerialization(const IColumn & column, size_t row_num, Arena & arena)
{
const char * begin = nullptr;
return column.serializeValueIntoArena(row_num, arena, begin);
}
template <>
inline StringRef AggregateFunctionTopKGeneric<true>::getSerialization(const IColumn & column, size_t row_num, Arena &)
{
return column.getDataAt(row_num);
}
template <>
inline void AggregateFunctionTopKGeneric<false>::deserializeAndInsert(StringRef str, IColumn & data_to)
{
data_to.deserializeAndInsertFromArena(str.data);
}
template <>
inline void AggregateFunctionTopKGeneric<true>::deserializeAndInsert(StringRef str, IColumn & data_to)
{
data_to.insertData(str.data, str.size);
}
#undef TOP_K_MAX_SIZE
#undef TOP_K_LOAD_FACTOR
}


@@ -0,0 +1,254 @@
#pragma once
#include <iostream>
#include <list>
#include <vector>
#include <boost/range/adaptor/reversed.hpp>
#include <Common/UInt128.h>
#include <Common/HashTable/Hash.h>
#include <Common/HashTable/HashMap.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadBuffer.h>
#include <IO/ReadHelpers.h>
#include <IO/VarInt.h>
/*
 * Implementation of the Filtered Space-Saving algorithm for top-k streaming analysis.
 * http://www.l2f.inesc-id.pt/~fmmb/wiki/uploads/Work/misnis.ref0a.pdf
 * It also implements the reduce-and-combine algorithm suggested in Parallel Space Saving:
* https://arxiv.org/pdf/1401.0702.pdf
*/
namespace DB
{
template <typename TKey, typename Hash = DefaultHash<TKey>>
class SpaceSaving
{
public:
struct Counter {
Counter() {}
Counter(const TKey & k, UInt64 c = 0, UInt64 e = 0)
: key(k), slot(0), count(c), error(e) {}
void write(DB::WriteBuffer & wb) const
{
DB::writeBinary(key, wb);
DB::writeVarUInt(count, wb);
DB::writeVarUInt(error, wb);
}
void read(DB::ReadBuffer & rb)
{
DB::readBinary(key, rb);
DB::readVarUInt(count, rb);
DB::readVarUInt(error, rb);
}
// greater() taking slot error into account
bool operator >(const Counter &b) const
{
return (count > b.count) || (count == b.count && error < b.error);
}
TKey key;
size_t slot;
UInt64 count, error;
};
// Suggested constants from the paper "Finding top-k elements in data streams", chapter 6, equation (24)
SpaceSaving(size_t c = 10) : counterMap(), counterList(), alphaMap(6 * c), cap(c) {}
~SpaceSaving() { destroyElements(); }
inline size_t size() const
{
return counterList.size();
}
inline size_t capacity() const
{
return cap;
}
void resize(size_t c)
{
counterList.reserve(c);
alphaMap.resize(c * 6);
cap = c;
}
Counter * insert(const TKey & key, UInt64 increment = 1, UInt64 error = 0)
{
// Increase the weight of a key that already exists.
// The hash table is used both for the value mapping and as a presence test (c_i != 0).
auto hash = counterMap.hash(key);
auto it = counterMap.find(key, hash);
if (it != counterMap.end()) {
auto c = it->second;
c->count += increment;
c->error += error;
percolate(c);
return c;
}
// Key doesn't exist, but can fit in the top K
if (size() < capacity()) {
auto c = new Counter(key, increment, error);
push(c);
return c;
}
auto min = counterList.back();
auto & alpha = alphaMap[hash % alphaMap.size()];
if (alpha + increment < min->count) {
alpha += increment;
return nullptr;
}
// Erase the current minimum element
auto minHash = counterMap.hash(min->key);
it = counterMap.find(min->key, minHash);
if (it != counterMap.end()) {
auto cell = it.getPtr();
cell->setZero();
}
// Replace minimum with newly inserted element
bool inserted = false;
counterMap.emplace(key, it, inserted, hash);
if (inserted) {
alphaMap[minHash % alphaMap.size()] = min->count;
min->key = key;
min->count = alpha + increment;
min->error = alpha + error;
it->second = min;
percolate(min);
}
return min;
}
/*
* Parallel Space Saving reduction and combine step from:
* https://arxiv.org/pdf/1401.0702.pdf
*/
void merge(const SpaceSaving<TKey, Hash> & rhs)
{
UInt64 m1 = 0, m2 = 0;
if (size() == capacity()) {
m1 = counterList.back()->count;
}
if (rhs.size() == rhs.capacity()) {
m2 = rhs.counterList.back()->count;
}
/*
* Updated algorithm to mutate current table in place
* without mutating rhs table or creating new one
* in the first step we expect that no elements overlap
* and in the second sweep we correct the error if they do.
*/
if (m2 > 0) {
for (auto c : counterList) {
c->count += m2;
c->error += m2;
}
}
// The list is sorted in descending order, we have to scan in reverse
for (auto c : boost::adaptors::reverse(rhs.counterList)) {
if (counterMap.find(c->key) != counterMap.end()) {
// Subtract m2 previously added, guaranteed not negative
insert(c->key, c->count - m2, c->error - m2);
} else {
// Counters not monitored in S1
insert(c->key, c->count + m1, c->error + m1);
}
}
}
std::vector<Counter> topK(size_t k) const
{
std::vector<Counter> res;
for (auto c : counterList) {
res.push_back(*c);
if (res.size() == k) {
break;
}
}
return res;
}
void write(DB::WriteBuffer & wb) const
{
DB::writeVarUInt(size(), wb);
for (auto c : counterList) {
c->write(wb);
}
for (auto a : alphaMap) {
DB::writeVarUInt(a, wb);
}
}
void read(DB::ReadBuffer & rb)
{
destroyElements();
size_t count = 0;
DB::readVarUInt(count, rb);
for (size_t i = 0; i < count; ++i) {
auto c = new Counter();
c->read(rb);
push(c);
}
for (size_t i = 0; i < capacity() * 6; ++i) {
UInt64 alpha = 0;
DB::readVarUInt(alpha, rb);
alphaMap.push_back(alpha);
}
}
protected:
void push(Counter * c) {
c->slot = counterList.size();
counterList.push_back(c);
counterMap[c->key] = c;
percolate(c);
}
// This is equivalent to one step of bubble sort
void percolate(Counter * c) {
while (c->slot > 0) {
auto next = counterList[c->slot - 1];
if (*c > *next) {
std::swap(next->slot, c->slot);
std::swap(counterList[next->slot], counterList[c->slot]);
} else {
break;
}
}
}
private:
void destroyElements() {
for (auto c : counterList) {
delete c;
}
counterMap.clear();
counterList.clear();
alphaMap.clear();
}
HashMap<TKey, Counter *, Hash> counterMap;
std::vector<Counter *> counterList;
std::vector<UInt64> alphaMap;
size_t cap;
};
};


@@ -54,3 +54,6 @@ target_link_libraries (thread_pool dbms)
add_executable (array_cache array_cache.cpp)
target_link_libraries (array_cache dbms)
add_executable (space_saving space_saving.cpp)
target_link_libraries (space_saving dbms)


@@ -0,0 +1,105 @@
#include <iostream>
#include <iomanip>
#include <string>
#include <map>
#include <Core/StringRef.h>
#include <Common/SpaceSaving.h>
int main(int argc, char ** argv)
{
{
using Cont = DB::SpaceSaving<int>;
Cont first(10);
/* Test biased insertion */
for (int i = 0; i < 200; ++i) {
first.insert(i);
int k = i % 5; // Bias towards 0-4
first.insert(k);
}
/* Test whether the biased elements are retained */
std::map<int, UInt64> expect;
for (int i = 0; i < 5; ++i) {
expect[i] = 41;
}
for (auto x : first.topK(5)) {
if (expect[x.key] != x.count) {
std::cerr << "key: " << x.key << " value: " << x.count << " expected: " << expect[x.key] << std::endl;
} else {
std::cout << "key: " << x.key << " value: " << x.count << std::endl;
}
expect.erase(x.key);
}
if (!expect.empty()) {
std::cerr << "expected to find all heavy hitters" << std::endl;
}
/* Create another table and test merging */
Cont second(10);
for (int i = 0; i < 200; ++i) {
second.insert(i);
}
for (int i = 0; i < 5; ++i) {
expect[i] = 42;
}
first.merge(second);
for (auto x : first.topK(5)) {
if (expect[x.key] != x.count) {
std::cerr << "key: " << x.key << " value: " << x.count << " expected: " << expect[x.key] << std::endl;
} else {
std::cout << "key: " << x.key << " value: " << x.count << std::endl;
}
expect.erase(x.key);
}
}
{
/* Same test for string keys */
using Cont = DB::SpaceSaving<StringRef, StringRefHash>;
Cont cont(10);
std::vector<std::string> refs;
for (int i = 0; i < 400; ++i) {
refs.push_back(std::to_string(i));
cont.insert(StringRef(refs.back()));
refs.push_back(std::to_string(i % 5)); // Bias towards 0-4
cont.insert(StringRef(refs.back()));
}
// The hashing of string keys is more lossy,
// so only check a loose lower bound on the counts (roughly 10%)
std::map<std::string, UInt64> expect;
for (int i = 0; i < 5; ++i) {
expect[std::to_string(i)] = 38;
}
for (auto x : cont.topK(5)) {
auto key = x.key.toString();
if (x.count < expect[key]) {
std::cerr << "key: " << key << " value: " << x.count << " expected: " << expect[key] << std::endl;
} else {
std::cout << "key: " << key << " value: " << x.count << std::endl;
}
expect.erase(key);
}
if (!expect.empty()) {
std::cerr << "expected to find all heavy hitters" << std::endl;
abort();
}
}
return 0;
}


@@ -583,6 +583,7 @@ inline typename std::enable_if<std::is_arithmetic<T>::value, void>::type
writeBinary(const T & x, WriteBuffer & buf) { writePODBinary(x, buf); }
inline void writeBinary(const String & x, WriteBuffer & buf) { writeStringBinary(x, buf); }
inline void writeBinary(const StringRef & x, WriteBuffer & buf) { writeStringBinary(x, buf); }
inline void writeBinary(const uint128 & x, WriteBuffer & buf) { writePODBinary(x, buf); }
inline void writeBinary(const LocalDate & x, WriteBuffer & buf) { writePODBinary(x, buf); }
inline void writeBinary(const LocalDateTime & x, WriteBuffer & buf) { writePODBinary(x, buf); }


@@ -0,0 +1 @@
[0,1,2,3,4,5,6,7,8,9]


@@ -0,0 +1 @@
SELECT topK(10)(n) FROM (SELECT if(number % 100 < 10, number % 10, number) AS n FROM system.numbers LIMIT 100000);


@@ -6418,6 +6418,17 @@ Usage example:
Problem: Generate a report that shows only keywords that produced at least 5 unique users.
Solution: Write in the query <span class="inline-example">GROUP BY SearchPhrase HAVING uniqUpTo(4)(UserID) >= 5</span>
==topK(N)(x)==
Returns an array of the N approximately most frequent values of the argument, sorted in descending order of their estimated frequency.
Recommended for use with small values of N, up to 10. The maximum value of N is 65536.
The state of the aggregate function uses approximately N * (size of the key + 16) bytes for the counters and 48 * N bytes for the alpha value map.
Usage example:
Problem: Generate a report that shows the top 5 most frequent search queries.
Solution: Write in the query <span class="inline-example">SELECT topK(5)(SearchPhrase)</span>
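For a rough sense of scale (illustrative numbers; an 8-byte numeric key is assumed): with N = 10 the estimate above gives 10 * (8 + 16) = 240 bytes for the counters plus 48 * 10 = 480 bytes for the alpha value map, roughly 720 bytes per aggregation state.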
==Aggregate function combinators==


@@ -6534,6 +6534,17 @@ cond1, cond2 ... - from one to 32 arguments of type UInt8
Problem: Show in the report only the search phrases that had at least 5 unique visitors.
Solution: Write in the query %%GROUP BY SearchPhrase HAVING uniqUpTo(4)(UserID) &gt;= 5%%
==topK(N)(x)==
Returns an array of the N approximately most frequent values of the argument, sorted in descending order of their estimated frequency.
Recommended for use with small values of N, up to 10. The maximum value of N is 65536.
The state of the aggregate function uses approximately N * (size of the key + 16) bytes for the counters and 48 * N bytes for the alpha value map.
Usage example:
Problem: Generate a report that shows the top 5 most frequent search queries.
Solution: Write in the query <span class="inline-example">SELECT topK(5)(SearchPhrase)</span>
==Aggregate function combinators==