add AggregateFunctionRetention

This commit is contained in:
sundy-li 2018-08-16 11:11:35 +08:00
parent 57aa1f9726
commit 63d74978d8
6 changed files with 228 additions and 0 deletions

View File

@ -0,0 +1,30 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionRetention.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
namespace
{
AggregateFunctionPtr createAggregateFunctionRetention(const std::string & name, const DataTypes & arguments, const Array & params)
{
assertNoParameters(name, params);
if (arguments.size() > AggregateFunctionRetentionData::max_events )
throw Exception("Too many event arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<AggregateFunctionRetention>(arguments);
}
}
void registerAggregateFunctionRetention(AggregateFunctionFactory & factory)
{
factory.registerFunction("retention", createAggregateFunctionRetention, AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -0,0 +1,149 @@
#pragma once
#include <iostream>
#include <sstream>
#include <unordered_set>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Common/ArenaAllocator.h>
#include <Common/typeid_cast.h>
#include <ext/range.h>
#include <bitset>
#include <AggregateFunctions/IAggregateFunction.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
}
struct AggregateFunctionRetentionData
{
static constexpr auto max_events = 32;
using Events = std::bitset<max_events>;
Events events;
void add(UInt8 event)
{
events.set(event);
}
void merge(const AggregateFunctionRetentionData & other)
{
events |= other.events;
}
void serialize(WriteBuffer & buf) const
{
UInt32 event_value = events.to_ulong();
writeBinary(event_value, buf);
}
void deserialize(ReadBuffer & buf)
{
UInt32 event_value;
readBinary(event_value, buf);
events = event_value;
}
};
/**
* The max size of events is 32, that's enough for retention analytics
*
* Usage:
* - retention(cond1, cond2, cond3, ....)
* - returns [cond1_flag, cond1_flag && cond2_flag, cond1_flag && cond3_flag, ...]
*/
class AggregateFunctionRetention final
: public IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>
{
private:
UInt8 events_size;
public:
String getName() const override
{
return "retention";
}
AggregateFunctionRetention(const DataTypes & arguments)
{
for (const auto i : ext::range(0, arguments.size()))
{
auto cond_arg = arguments[i].get();
if (!typeid_cast<const DataTypeUInt8 *>(cond_arg))
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i) + " of aggregate function "
+ getName() + ", must be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
events_size = arguments.size();
}
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>());
}
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
for (const auto i : ext::range(0, events_size))
{
auto event = static_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num];
if (event)
{
this->data(place).add(i);
break;
}
}
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
this->data(place).serialize(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
{
this->data(place).deserialize(buf);
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
auto & data_to = static_cast<ColumnArray &>(to).getData();
auto & offsets_to = static_cast<ColumnArray &>(to).getOffsets();
const auto first_flag = this->data(place).events.test(0);
data_to.insert(first_flag ? Field(static_cast<UInt64>(1)) : Field(static_cast<UInt64>(0)));
for (const auto i : ext::range(1, events_size))
{
if (first_flag && this->data(place).events.test(i))
data_to.insert(Field(static_cast<UInt64>(1)));
else data_to.insert(Field(static_cast<UInt64>(0)));
}
offsets_to.push_back(offsets_to.size() == 0 ? events_size : offsets_to.back() + events_size);
}
const char * getHeaderFilePath() const override
{
return __FILE__;
}
};
}

View File

@ -34,6 +34,7 @@ void registerAggregateFunctionCombinatorMerge(AggregateFunctionCombinatorFactory
void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionHistogram(AggregateFunctionFactory & factory);
void registerAggregateFunctionRetention(AggregateFunctionFactory & factory);
void registerAggregateFunctions()
{
@ -59,6 +60,7 @@ void registerAggregateFunctions()
registerAggregateFunctionsBitwise(factory);
registerAggregateFunctionsMaxIntersections(factory);
registerAggregateFunctionHistogram(factory);
registerAggregateFunctionRetention(factory);
}
{

View File

@ -0,0 +1,4 @@
80
80 50
80 60
80 50 60

View File

@ -0,0 +1,13 @@
DROP TABLE IF EXISTS retention_test;
CREATE TABLE retention_test(date Date, uid Int32)ENGINE = Memory;
INSERT INTO retention_test SELECT '2018-08-06', number FROM numbers(80);
INSERT INTO retention_test SELECT '2018-08-07', number FROM numbers(50);
INSERT INTO retention_test SELECT '2018-08-08', number FROM numbers(60);
SELECT sum(r[1]) as r1 FROM (SELECT uid, retention(date = '2018-08-06') AS r FROM retention_test WHERE date IN ('2018-08-06') GROUP BY uid);
SELECT sum(r[1]) as r1, sum(r[2]) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07') AS r FROM retention_test WHERE date IN ('2018-08-06', '2018-08-07') GROUP BY uid);
SELECT sum(r[1]) as r1, sum(r[2]) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-08') AS r FROM retention_test WHERE date IN ('2018-08-06', '2018-08-08') GROUP BY uid);
SELECT sum(r[1]) as r1, sum(r[2]) as r2, sum(r[3]) as r3 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07', date = '2018-08-08') AS r FROM retention_test WHERE date IN ('2018-08-06', '2018-08-07', '2018-08-08') GROUP BY uid);
DROP TABLE retention_test;

View File

@ -85,6 +85,36 @@ ORDER BY level
Simply, the level value could only be 0, 1, 2, 3, it means the maxium event action stage that one user could reach.
## retention(cond1, cond2, ...)
Retention refers to the ability of a company or product to retain its customers over some specified periods.
`cond1`, `cond2` ... is from one to 32 arguments of type UInt8 that indicate whether a certain condition was met for the event
Example:
Consider you are doing a website analytics, intend to calculate the retention of customers
This could be easily calculate by `retention`
```
SELECT
sum(r[1]) AS r1,
sum(r[2]) AS r2,
sum(r[3]) AS r3
FROM
(
SELECT
uid,
retention(date = '2018-08-10', date = '2018-08-11', date = '2018-08-12') AS r
FROM events
WHERE date IN ('2018-08-10', '2018-08-11', '2018-08-12')
GROUP BY uid
)
```
Simply, `r1` means the number of unique visitors who met the `cond1` condition, `r2` means the number of unique visitors who met `cond1` and `cond2` conditions, `r3` means the number of unique visitors who met `cond1` and `cond3` conditions.
## uniqUpTo(N)(x)