Make sequenceMatch support other unsigned integer types

This commit is contained in:
sundy-li 2019-05-20 12:02:54 +08:00
parent bc99be0f10
commit 6843254ce4
4 changed files with 91 additions and 74 deletions

View File

@ -1,6 +1,12 @@
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionSequenceMatch.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <ext/range.h>
namespace DB
{
@ -12,32 +18,58 @@ namespace ErrorCodes
namespace
{
AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & name, const DataTypes & argument_types, const Array & params)
template <template<typename,typename> class AggregateFunction, template <typename> class Data>
AggregateFunctionPtr createAggregateFunctionSequenceBase(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (params.size() != 1)
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, params, pattern);
}
const auto arg_count = argument_types.size();
AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (params.size() != 1)
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
if (arg_count < 3)
throw Exception{"Aggregate function " + name + " requires at least 3 arguments.",
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
if (arg_count - 1 > max_events)
throw Exception{"Aggregate function " + name + " supports up to "
+ toString(max_events) + " event arguments.",
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
const auto time_arg = argument_types.front().get();
for (const auto i : ext::range(1, arg_count))
{
const auto cond_arg = argument_types[i].get();
if (!isUInt8(cond_arg))
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1)
+ " of aggregate function " + name + ", must be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, params, pattern);
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunction, Data>(*argument_types[0], argument_types, params, pattern));
if (res)
return res;
WhichDataType which(argument_types.front().get());
if (which.isDateTime())
return std::make_shared<AggregateFunction<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(argument_types, params, pattern);
else if (which.isDate())
return std::make_shared<AggregateFunction<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(argument_types, params, pattern);
/// Reached only when the first argument is neither an unsigned integer, DateTime nor Date.
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
+ name + ", must be DateTime, Date or an unsigned integer",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
}
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory)
{
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceMatch);
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceCount);
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceBase<AggregateFunctionSequenceMatch, AggregateFunctionSequenceMatchData>);
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceBase<AggregateFunctionSequenceCount, AggregateFunctionSequenceMatchData>);
}
}

View File

@ -36,11 +36,12 @@ struct ComparePairFirst final
}
};
static constexpr auto max_events = 32;
template <typename T>
struct AggregateFunctionSequenceMatchData final
{
static constexpr auto max_events = 32;
using Timestamp = std::uint32_t;
using Timestamp = T;
using Events = std::bitset<max_events>;
using TimestampEvents = std::pair<Timestamp, Events>;
using Comparator = ComparePairFirst<std::less>;
@ -119,7 +120,7 @@ struct AggregateFunctionSequenceMatchData final
for (size_t i = 0; i < size; ++i)
{
std::uint32_t timestamp;
Timestamp timestamp;
readBinary(timestamp, buf);
UInt64 events;
@ -135,48 +136,23 @@ struct AggregateFunctionSequenceMatchData final
constexpr auto sequence_match_max_iterations = 1000000;
template <typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>
template <typename T, typename Data, typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
{
public:
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern)
: IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>(arguments, params)
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params)
, pattern(pattern)
{
arg_count = arguments.size();
if (!sufficientArgs(arg_count))
throw Exception{"Aggregate function " + derived().getName() + " requires at least 3 arguments.",
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
if (arg_count - 1 > AggregateFunctionSequenceMatchData::max_events)
throw Exception{"Aggregate function " + derived().getName() + " supports up to " +
toString(AggregateFunctionSequenceMatchData::max_events) + " event arguments.",
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
const auto time_arg = arguments.front().get();
if (!WhichDataType(time_arg).isDateTime())
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
+ derived().getName() + ", must be DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
for (const auto i : ext::range(1, arg_count))
{
const auto cond_arg = arguments[i].get();
if (!isUInt8(cond_arg))
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) +
" of aggregate function " + derived().getName() + ", must be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
parsePattern();
}
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
const auto timestamp = static_cast<const ColumnUInt32 *>(columns[0])->getData()[row_num];
const auto timestamp = static_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
AggregateFunctionSequenceMatchData::Events events;
typename Data::Events events;
for (const auto i : ext::range(1, arg_count))
{
const auto event = static_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
@ -227,8 +203,6 @@ private:
static constexpr size_t bytes_on_stack = 64;
using PatternActions = PODArray<PatternAction, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;
static bool sufficientArgs(const size_t arg_count) { return arg_count >= 3; }
Derived & derived() { return static_cast<Derived &>(*this); }
void parsePattern()
@ -340,8 +314,8 @@ protected:
/// This algorithm performs in O(mn) (with m the number of DFA states and N the number
/// of events) with a memory consumption and memory allocations in O(m). It means that
/// if n >>> m (which is expected to be the case), this algorithm can be considered linear.
template <typename T>
bool dfaMatch(T & events_it, const T events_end) const
template <typename EventEntry>
bool dfaMatch(EventEntry & events_it, const EventEntry events_end) const
{
using ActiveStates = std::vector<bool>;
@ -396,8 +370,8 @@ protected:
return active_states.back();
}
template <typename T>
bool backtrackingMatch(T & events_it, const T events_end) const
template <typename EventEntry>
bool backtrackingMatch(EventEntry & events_it, const EventEntry events_end) const
{
const auto action_begin = std::begin(actions);
const auto action_end = std::end(actions);
@ -407,7 +381,7 @@ protected:
auto base_it = events_it;
/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
using backtrack_info = std::tuple<decltype(action_it), T, T>;
using backtrack_info = std::tuple<decltype(action_it), EventEntry, EventEntry>;
std::stack<backtrack_info> back_stack;
/// backtrack if possible
@ -458,7 +432,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeLessOrEqual)
{
if (events_it->first - base_it->first <= action_it->extra)
if (events_it->first <= static_cast<UInt64>(base_it->first + action_it->extra))
{
/// condition satisfied, move onto next action
back_stack.emplace(action_it, events_it, base_it);
@ -470,7 +444,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeLess)
{
if (events_it->first - base_it->first < action_it->extra)
/// TimeLess: the event timestamp must be strictly below the base timestamp plus the allowed delta.
if (events_it->first < static_cast<UInt64>(base_it->first + action_it->extra))
{
back_stack.emplace(action_it, events_it, base_it);
base_it = events_it;
@ -481,7 +455,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
{
if (events_it->first - base_it->first >= action_it->extra)
if (events_it->first >= static_cast<UInt64>(base_it->first + action_it->extra))
{
back_stack.emplace(action_it, events_it, base_it);
base_it = events_it;
@ -492,7 +466,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeGreater)
{
if (events_it->first - base_it->first > action_it->extra)
if (events_it->first > static_cast<UInt64>(base_it->first + action_it->extra))
{
back_stack.emplace(action_it, events_it, base_it);
base_it = events_it;
@ -575,14 +549,14 @@ private:
DFAStates dfa_states;
};
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>
template <typename T, typename Data>
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>
{
public:
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>(arguments, params, pattern) {}
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>::AggregateFunctionSequenceBase;
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceMatch"; }
@ -590,27 +564,27 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const_cast<Data &>(data(place)).sort();
const_cast<Data &>(this->data(place)).sort();
const auto & data_ref = data(place);
const auto & data_ref = this->data(place);
const auto events_begin = std::begin(data_ref.events_list);
const auto events_end = std::end(data_ref.events_list);
auto events_it = events_begin;
bool match = pattern_has_time ? backtrackingMatch(events_it, events_end) : dfaMatch(events_it, events_end);
bool match = this->pattern_has_time ? this->backtrackingMatch(events_it, events_end) : this->dfaMatch(events_it, events_end);
static_cast<ColumnUInt8 &>(to).getData().push_back(match);
}
};
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>
template <typename T, typename Data>
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>
{
public:
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>(arguments, params, pattern) {}
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>::AggregateFunctionSequenceBase;
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceCount"; }
@ -618,21 +592,21 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const_cast<Data &>(data(place)).sort();
const_cast<Data &>(this->data(place)).sort();
static_cast<ColumnUInt64 &>(to).getData().push_back(count(place));
}
private:
UInt64 count(const ConstAggregateDataPtr & place) const
{
const auto & data_ref = data(place);
const auto & data_ref = this->data(place);
const auto events_begin = std::begin(data_ref.events_list);
const auto events_end = std::end(data_ref.events_list);
auto events_it = events_begin;
size_t count = 0;
while (events_it != events_end && backtrackingMatch(events_it, events_end))
while (events_it != events_end && this->backtrackingMatch(events_it, events_end))
++count;
return count;

View File

@ -12,14 +12,24 @@ ENGINE = Memory;
INSERT INTO sequence SELECT 1, number = 0 ? 'A' : (number < 1000000 ? 'B' : 'C'), number FROM numbers(1000001);
SELECT userID
SELECT 'ABC'
FROM sequence
GROUP BY userID
HAVING sequenceMatch('(?1).*(?2).*(?3)')(toDateTime(EventTime), eventType = 'A', eventType = 'B', eventType = 'C');
SELECT userID
SELECT 'ABA'
FROM sequence
GROUP BY userID
HAVING sequenceMatch('(?1).*(?2).*(?3)')(toDateTime(EventTime), eventType = 'A', eventType = 'B', eventType = 'A');
SELECT 'ABBC'
FROM sequence
GROUP BY userID
HAVING sequenceMatch('(?1).*(?2).*(?3).*(?4)')(EventTime, eventType = 'A', eventType = 'B', eventType = 'B',eventType = 'C');
SELECT 'CBA'
FROM sequence
GROUP BY userID
HAVING sequenceMatch('(?1).*(?2).*(?3)')(EventTime, eventType = 'C', eventType = 'B', eventType = 'A');
DROP TABLE sequence;