mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
add AggregateFunctionRetention
This commit is contained in:
parent
57aa1f9726
commit
63d74978d8
30
dbms/src/AggregateFunctions/AggregateFunctionRetention.cpp
Normal file
30
dbms/src/AggregateFunctions/AggregateFunctionRetention.cpp
Normal file
@ -0,0 +1,30 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionRetention.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionRetention(const std::string & name, const DataTypes & arguments, const Array & params)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
|
||||
if (arguments.size() > AggregateFunctionRetentionData::max_events )
|
||||
throw Exception("Too many event arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
return std::make_shared<AggregateFunctionRetention>(arguments);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionRetention(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("retention", createAggregateFunctionRetention, AggregateFunctionFactory::CaseInsensitive);
|
||||
}
|
||||
|
||||
}
|
149
dbms/src/AggregateFunctions/AggregateFunctionRetention.h
Normal file
149
dbms/src/AggregateFunctions/AggregateFunctionRetention.h
Normal file
@ -0,0 +1,149 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Common/ArenaAllocator.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <ext/range.h>
|
||||
#include <bitset>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
|
||||
}
|
||||
|
||||
struct AggregateFunctionRetentionData
|
||||
{
|
||||
static constexpr auto max_events = 32;
|
||||
|
||||
using Events = std::bitset<max_events>;
|
||||
|
||||
Events events;
|
||||
|
||||
void add(UInt8 event)
|
||||
{
|
||||
events.set(event);
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionRetentionData & other)
|
||||
{
|
||||
events |= other.events;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
UInt32 event_value = events.to_ulong();
|
||||
writeBinary(event_value, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
UInt32 event_value;
|
||||
readBinary(event_value, buf);
|
||||
events = event_value;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* The max size of events is 32, that's enough for retention analytics
|
||||
*
|
||||
* Usage:
|
||||
* - retention(cond1, cond2, cond3, ....)
|
||||
* - returns [cond1_flag, cond1_flag && cond2_flag, cond1_flag && cond3_flag, ...]
|
||||
*/
|
||||
class AggregateFunctionRetention final
|
||||
: public IAggregateFunctionDataHelper<AggregateFunctionRetentionData, AggregateFunctionRetention>
|
||||
{
|
||||
private:
|
||||
UInt8 events_size;
|
||||
|
||||
public:
|
||||
String getName() const override
|
||||
{
|
||||
return "retention";
|
||||
}
|
||||
|
||||
AggregateFunctionRetention(const DataTypes & arguments)
|
||||
{
|
||||
for (const auto i : ext::range(0, arguments.size()))
|
||||
{
|
||||
auto cond_arg = arguments[i].get();
|
||||
if (!typeid_cast<const DataTypeUInt8 *>(cond_arg))
|
||||
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i) + " of aggregate function "
|
||||
+ getName() + ", must be UInt8",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
events_size = arguments.size();
|
||||
}
|
||||
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeUInt8>());
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
for (const auto i : ext::range(0, events_size))
|
||||
{
|
||||
auto event = static_cast<const ColumnVector<UInt8> *>(columns[i])->getData()[row_num];
|
||||
if (event)
|
||||
{
|
||||
this->data(place).add(i);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
auto & data_to = static_cast<ColumnArray &>(to).getData();
|
||||
auto & offsets_to = static_cast<ColumnArray &>(to).getOffsets();
|
||||
|
||||
const auto first_flag = this->data(place).events.test(0);
|
||||
data_to.insert(first_flag ? Field(static_cast<UInt64>(1)) : Field(static_cast<UInt64>(0)));
|
||||
for (const auto i : ext::range(1, events_size))
|
||||
{
|
||||
if (first_flag && this->data(place).events.test(i))
|
||||
data_to.insert(Field(static_cast<UInt64>(1)));
|
||||
else data_to.insert(Field(static_cast<UInt64>(0)));
|
||||
}
|
||||
offsets_to.push_back(offsets_to.size() == 0 ? events_size : offsets_to.back() + events_size);
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override
|
||||
{
|
||||
return __FILE__;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
@ -34,6 +34,7 @@ void registerAggregateFunctionCombinatorMerge(AggregateFunctionCombinatorFactory
|
||||
void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory &);
|
||||
|
||||
void registerAggregateFunctionHistogram(AggregateFunctionFactory & factory);
|
||||
void registerAggregateFunctionRetention(AggregateFunctionFactory & factory);
|
||||
|
||||
void registerAggregateFunctions()
|
||||
{
|
||||
@ -59,6 +60,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionsBitwise(factory);
|
||||
registerAggregateFunctionsMaxIntersections(factory);
|
||||
registerAggregateFunctionHistogram(factory);
|
||||
registerAggregateFunctionRetention(factory);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -0,0 +1,4 @@
|
||||
80
|
||||
80 50
|
||||
80 60
|
||||
80 50 60
|
@ -0,0 +1,13 @@
|
||||
DROP TABLE IF EXISTS retention_test;
|
||||
|
||||
CREATE TABLE retention_test(date Date, uid Int32)ENGINE = Memory;
|
||||
INSERT INTO retention_test SELECT '2018-08-06', number FROM numbers(80);
|
||||
INSERT INTO retention_test SELECT '2018-08-07', number FROM numbers(50);
|
||||
INSERT INTO retention_test SELECT '2018-08-08', number FROM numbers(60);
|
||||
|
||||
SELECT sum(r[1]) as r1 FROM (SELECT uid, retention(date = '2018-08-06') AS r FROM retention_test WHERE date IN ('2018-08-06') GROUP BY uid);
|
||||
SELECT sum(r[1]) as r1, sum(r[2]) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07') AS r FROM retention_test WHERE date IN ('2018-08-06', '2018-08-07') GROUP BY uid);
|
||||
SELECT sum(r[1]) as r1, sum(r[2]) as r2 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-08') AS r FROM retention_test WHERE date IN ('2018-08-06', '2018-08-08') GROUP BY uid);
|
||||
SELECT sum(r[1]) as r1, sum(r[2]) as r2, sum(r[3]) as r3 FROM (SELECT uid, retention(date = '2018-08-06', date = '2018-08-07', date = '2018-08-08') AS r FROM retention_test WHERE date IN ('2018-08-06', '2018-08-07', '2018-08-08') GROUP BY uid);
|
||||
|
||||
DROP TABLE retention_test;
|
@ -85,6 +85,36 @@ ORDER BY level
|
||||
|
||||
Simply, the level value could only be 0, 1, 2, 3, it means the maxium event action stage that one user could reach.
|
||||
|
||||
## retention(cond1, cond2, ...)
|
||||
|
||||
Retention refers to the ability of a company or product to retain its customers over some specified periods.
|
||||
|
||||
`cond1`, `cond2` ... is from one to 32 arguments of type UInt8 that indicate whether a certain condition was met for the event
|
||||
|
||||
Example:
|
||||
|
||||
Consider you are doing a website analytics, intend to calculate the retention of customers
|
||||
|
||||
This could be easily calculate by `retention`
|
||||
|
||||
```
|
||||
SELECT
|
||||
sum(r[1]) AS r1,
|
||||
sum(r[2]) AS r2,
|
||||
sum(r[3]) AS r3
|
||||
FROM
|
||||
(
|
||||
SELECT
|
||||
uid,
|
||||
retention(date = '2018-08-10', date = '2018-08-11', date = '2018-08-12') AS r
|
||||
FROM events
|
||||
WHERE date IN ('2018-08-10', '2018-08-11', '2018-08-12')
|
||||
GROUP BY uid
|
||||
)
|
||||
```
|
||||
|
||||
Simply, `r1` means the number of unique visitors who met the `cond1` condition, `r2` means the number of unique visitors who met `cond1` and `cond2` conditions, `r3` means the number of unique visitors who met `cond1` and `cond3` conditions.
|
||||
|
||||
|
||||
## uniqUpTo(N)(x)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user