support unsigned integer type in windowFunnel Aggregate function

This commit is contained in:
sundy-li 2019-05-17 19:17:52 +08:00
parent bc99be0f10
commit 760bc5708d
4 changed files with 52 additions and 23 deletions

View File

@ -2,6 +2,8 @@
#include <AggregateFunctions/AggregateFunctionWindowFunnel.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
namespace DB
@ -10,6 +12,7 @@ namespace DB
namespace
{
template <template <typename> class Data>
AggregateFunctionPtr createAggregateFunctionWindowFunnel(const std::string & name, const DataTypes & arguments, const Array & params)
{
if (params.size() != 1)
@ -18,17 +21,27 @@ AggregateFunctionPtr createAggregateFunctionWindowFunnel(const std::string & nam
if (arguments.size() < 2)
throw Exception("Aggregate function " + name + " requires one timestamp argument and at least one event condition.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (arguments.size() > AggregateFunctionWindowFunnelData::max_events + 1)
if (arguments.size() > max_events + 1)
throw Exception("Too many event arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<AggregateFunctionWindowFunnel>(arguments, params);
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunctionWindowFunnel, Data>(*arguments[0], arguments, params));
WhichDataType which(arguments.front().get());
if (res)
return res;
else if (which.isDate())
return std::make_shared<AggregateFunctionWindowFunnel<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(arguments, params);
else if (which.isDateTime())
return std::make_shared<AggregateFunctionWindowFunnel<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(arguments, params);
throw Exception{"Illegal type " + arguments.front().get()->getName() + " of first argument of aggregate function "
+ name + ", must be Number, Date, DateTime", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
}
void registerAggregateFunctionWindowFunnel(AggregateFunctionFactory & factory)
{
factory.registerFunction("windowFunnel", createAggregateFunctionWindowFunnel, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("windowFunnel", createAggregateFunctionWindowFunnel<AggregateFunctionWindowFunnelData>, AggregateFunctionFactory::CaseInsensitive);
}
}

View File

@ -13,7 +13,6 @@
#include <AggregateFunctions/IAggregateFunction.h>
namespace DB
{
@ -33,10 +32,11 @@ struct ComparePairFirst final
}
};
static constexpr auto max_events = 32;
template <typename T>
struct AggregateFunctionWindowFunnelData
{
static constexpr auto max_events = 32;
using TimestampEvent = std::pair<UInt32, UInt8>;
using TimestampEvent = std::pair<T, UInt8>;
static constexpr size_t bytes_on_stack = 64;
using TimestampEvents = PODArray<TimestampEvent, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;
@ -51,7 +51,7 @@ struct AggregateFunctionWindowFunnelData
return events_list.size();
}
void add(UInt32 timestamp, UInt8 event)
void add(T timestamp, UInt8 event)
{
// Since most events should have already been sorted by timestamp.
if (sorted && events_list.size() > 0 && events_list.back().first > timestamp)
@ -119,7 +119,7 @@ struct AggregateFunctionWindowFunnelData
events_list.clear();
events_list.reserve(size);
UInt32 timestamp;
T timestamp;
UInt8 event;
for (size_t i = 0; i < size; ++i)
@ -137,11 +137,12 @@ struct AggregateFunctionWindowFunnelData
* Usage:
* - windowFunnel(window)(timestamp, cond1, cond2, cond3, ....)
*/
template <typename T, typename Data>
class AggregateFunctionWindowFunnel final
: public IAggregateFunctionDataHelper<AggregateFunctionWindowFunnelData, AggregateFunctionWindowFunnel>
: public IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>
{
private:
UInt32 window;
UInt64 window;
UInt8 events_size;
@ -149,22 +150,24 @@ private:
// The level path must be 1---2---3---...---check_events_size, find the max event level that statisfied the path in the sliding window.
// If found, returns the max event level, else return 0.
// The Algorithm complexity is O(n).
UInt8 getEventLevel(const AggregateFunctionWindowFunnelData & data) const
UInt8 getEventLevel(const Data & data) const
{
if (data.size() == 0)
return 0;
if (events_size == 1)
return 1;
const_cast<AggregateFunctionWindowFunnelData &>(data).sort();
const_cast<Data &>(data).sort();
// events_timestamp stores the timestamp that latest i-th level event happen withing time window after previous level event.
// timestamp defaults to -1, which unsigned timestamp value never meet
std::vector<Int32> events_timestamp(events_size, -1);
/// events_timestamp stores the timestamp that latest i-th level event happen withing time window after previous level event.
/// timestamp defaults to -1, which unsigned timestamp value never meet
/// there may be some bugs when UInt64 type timstamp overflows Int64, but it works on most cases.
std::vector<Int64> events_timestamp(events_size, -1);
for (const auto & pair : data.events_list)
{
const auto & timestamp = pair.first;
const T & timestamp = pair.first;
const auto & event_idx = pair.second - 1;
if (event_idx == 0)
events_timestamp[0] = timestamp;
else if (events_timestamp[event_idx - 1] >= 0 && timestamp <= events_timestamp[event_idx - 1] + window)
@ -189,13 +192,8 @@ public:
}
AggregateFunctionWindowFunnel(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper<AggregateFunctionWindowFunnelData, AggregateFunctionWindowFunnel>(arguments, params)
: IAggregateFunctionDataHelper<Data, AggregateFunctionWindowFunnel<T, Data>>(arguments, params)
{
const auto time_arg = arguments.front().get();
if (!WhichDataType(time_arg).isDateTime() && !WhichDataType(time_arg).isUInt32())
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function " + getName()
+ ", must be DateTime or UInt32", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
for (const auto i : ext::range(1, arguments.size()))
{
auto cond_arg = arguments[i].get();
@ -217,7 +215,7 @@ public:
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
const auto timestamp = static_cast<const ColumnVector<UInt32> *>(columns[0])->getData()[row_num];
const auto timestamp = static_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
// reverse iteration and stable sorting are needed for events that are qualified by more than one condition.
for (auto i = events_size; i > 0; --i)
{

View File

@ -11,3 +11,8 @@
1
1
1
1
1
1
1
1

View File

@ -27,5 +27,18 @@ select 1 = windowFunnel(10000)(timestamp, event = 1008, event = 1001) from funne
select 5 = windowFunnel(4)(timestamp, event = 1003, event = 1004, event = 1005, event = 1006, event = 1007) from funnel_test2;
select 4 = windowFunnel(4)(timestamp, event <= 1007, event >= 1002, event <= 1006, event >= 1004) from funnel_test2;
drop table if exists funnel_test_u64;
create table funnel_test_u64 (uid UInt32 default 1,timestamp UInt64, event UInt32) engine=Memory;
insert into funnel_test_u64(timestamp, event) values ( 1e14 + 1 ,1001),(1e14 + 2,1002),(1e14 + 3,1003),(1e14 + 4,1004),(1e14 + 5,1005),(1e14 + 6,1006),(1e14 + 7,1007),(1e14 + 8,1008);
select 5 = windowFunnel(4)(timestamp, event = 1003, event = 1004, event = 1005, event = 1006, event = 1007) from funnel_test_u64;
select 2 = windowFunnel(10000)(timestamp, event = 1001, event = 1008) from funnel_test_u64;
select 1 = windowFunnel(10000)(timestamp, event = 1008, event = 1001) from funnel_test_u64;
select 5 = windowFunnel(4)(timestamp, event = 1003, event = 1004, event = 1005, event = 1006, event = 1007) from funnel_test_u64;
select 4 = windowFunnel(4)(timestamp, event <= 1007, event >= 1002, event <= 1006, event >= 1004) from funnel_test_u64;
drop table funnel_test;
drop table funnel_test2;
drop table funnel_test_u64;