mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Make sequenceMatch support other unsigned integer types
This commit is contained in:
parent
bc99be0f10
commit
6843254ce4
@ -1,6 +1,12 @@
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionSequenceMatch.h>
|
||||
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
|
||||
#include <ext/range.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
@ -12,32 +18,58 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
template <template<typename,typename> class AggregateFunction, template <typename> class Data>
|
||||
AggregateFunctionPtr createAggregateFunctionSequenceBase(const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
|
||||
|
||||
String pattern = params.front().safeGet<std::string>();
|
||||
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, params, pattern);
|
||||
}
|
||||
const auto arg_count = argument_types.size();
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
|
||||
if (arg_count < 3)
|
||||
throw Exception{"Aggregate function " + name + " requires at least 3 arguments.",
|
||||
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
|
||||
|
||||
if (arg_count - 1 > max_events)
|
||||
throw Exception{"Aggregate function " + name + " supports up to "
|
||||
+ toString(max_events) + " event arguments.",
|
||||
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
|
||||
|
||||
const auto time_arg = argument_types.front().get();
|
||||
|
||||
for (const auto i : ext::range(1, arg_count))
|
||||
{
|
||||
const auto cond_arg = argument_types[i].get();
|
||||
if (!isUInt8(cond_arg))
|
||||
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1)
|
||||
+ " of aggregate function " + name + ", must be UInt8",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
String pattern = params.front().safeGet<std::string>();
|
||||
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, params, pattern);
|
||||
|
||||
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunction, Data>(*argument_types[0], argument_types, params, pattern));
|
||||
if (res)
|
||||
return res;
|
||||
|
||||
WhichDataType which(argument_types.front().get());
|
||||
if (which.isDateTime())
|
||||
return std::make_shared<AggregateFunction<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(argument_types, params, pattern);
|
||||
else if (which.isDate())
|
||||
return std::make_shared<AggregateFunction<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(argument_types, params, pattern);
|
||||
|
||||
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
|
||||
+ name + ", must be DateTime",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceMatch);
|
||||
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceCount);
|
||||
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceBase<AggregateFunctionSequenceMatch, AggregateFunctionSequenceMatchData>);
|
||||
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceBase<AggregateFunctionSequenceCount, AggregateFunctionSequenceMatchData>);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -36,11 +36,12 @@ struct ComparePairFirst final
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr auto max_events = 32;
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionSequenceMatchData final
|
||||
{
|
||||
static constexpr auto max_events = 32;
|
||||
|
||||
using Timestamp = std::uint32_t;
|
||||
using Timestamp = T;
|
||||
using Events = std::bitset<max_events>;
|
||||
using TimestampEvents = std::pair<Timestamp, Events>;
|
||||
using Comparator = ComparePairFirst<std::less>;
|
||||
@ -119,7 +120,7 @@ struct AggregateFunctionSequenceMatchData final
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
std::uint32_t timestamp;
|
||||
Timestamp timestamp;
|
||||
readBinary(timestamp, buf);
|
||||
|
||||
UInt64 events;
|
||||
@ -135,48 +136,23 @@ struct AggregateFunctionSequenceMatchData final
|
||||
constexpr auto sequence_match_max_iterations = 1000000;
|
||||
|
||||
|
||||
template <typename Derived>
|
||||
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>
|
||||
template <typename T, typename Data, typename Derived>
|
||||
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>(arguments, params)
|
||||
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params)
|
||||
, pattern(pattern)
|
||||
{
|
||||
arg_count = arguments.size();
|
||||
|
||||
if (!sufficientArgs(arg_count))
|
||||
throw Exception{"Aggregate function " + derived().getName() + " requires at least 3 arguments.",
|
||||
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
|
||||
|
||||
if (arg_count - 1 > AggregateFunctionSequenceMatchData::max_events)
|
||||
throw Exception{"Aggregate function " + derived().getName() + " supports up to " +
|
||||
toString(AggregateFunctionSequenceMatchData::max_events) + " event arguments.",
|
||||
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
|
||||
|
||||
const auto time_arg = arguments.front().get();
|
||||
if (!WhichDataType(time_arg).isDateTime())
|
||||
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
|
||||
+ derived().getName() + ", must be DateTime",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
for (const auto i : ext::range(1, arg_count))
|
||||
{
|
||||
const auto cond_arg = arguments[i].get();
|
||||
if (!isUInt8(cond_arg))
|
||||
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) +
|
||||
" of aggregate function " + derived().getName() + ", must be UInt8",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
parsePattern();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto timestamp = static_cast<const ColumnUInt32 *>(columns[0])->getData()[row_num];
|
||||
const auto timestamp = static_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
|
||||
|
||||
AggregateFunctionSequenceMatchData::Events events;
|
||||
typename Data::Events events;
|
||||
for (const auto i : ext::range(1, arg_count))
|
||||
{
|
||||
const auto event = static_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
|
||||
@ -227,8 +203,6 @@ private:
|
||||
static constexpr size_t bytes_on_stack = 64;
|
||||
using PatternActions = PODArray<PatternAction, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;
|
||||
|
||||
static bool sufficientArgs(const size_t arg_count) { return arg_count >= 3; }
|
||||
|
||||
Derived & derived() { return static_cast<Derived &>(*this); }
|
||||
|
||||
void parsePattern()
|
||||
@ -340,8 +314,8 @@ protected:
|
||||
/// This algorithm performs in O(mn) (with m the number of DFA states and N the number
|
||||
/// of events) with a memory consumption and memory allocations in O(m). It means that
|
||||
/// if n >>> m (which is expected to be the case), this algorithm can be considered linear.
|
||||
template <typename T>
|
||||
bool dfaMatch(T & events_it, const T events_end) const
|
||||
template <typename EventEntry>
|
||||
bool dfaMatch(EventEntry & events_it, const EventEntry events_end) const
|
||||
{
|
||||
using ActiveStates = std::vector<bool>;
|
||||
|
||||
@ -396,8 +370,8 @@ protected:
|
||||
return active_states.back();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool backtrackingMatch(T & events_it, const T events_end) const
|
||||
template <typename EventEntry>
|
||||
bool backtrackingMatch(EventEntry & events_it, const EventEntry events_end) const
|
||||
{
|
||||
const auto action_begin = std::begin(actions);
|
||||
const auto action_end = std::end(actions);
|
||||
@ -407,7 +381,7 @@ protected:
|
||||
auto base_it = events_it;
|
||||
|
||||
/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
|
||||
using backtrack_info = std::tuple<decltype(action_it), T, T>;
|
||||
using backtrack_info = std::tuple<decltype(action_it), EventEntry, EventEntry>;
|
||||
std::stack<backtrack_info> back_stack;
|
||||
|
||||
/// backtrack if possible
|
||||
@ -458,7 +432,7 @@ protected:
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeLessOrEqual)
|
||||
{
|
||||
if (events_it->first - base_it->first <= action_it->extra)
|
||||
if (events_it->first <= static_cast<UInt64>(base_it->first + action_it->extra))
|
||||
{
|
||||
/// condition satisfied, move onto next action
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
@ -470,7 +444,7 @@ protected:
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeLess)
|
||||
{
|
||||
if (events_it->first - base_it->first < action_it->extra)
|
||||
if (events_it->first < static_cast<UInt64>(base_it->first < action_it->extra))
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
@ -481,7 +455,7 @@ protected:
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
|
||||
{
|
||||
if (events_it->first - base_it->first >= action_it->extra)
|
||||
if (events_it->first >= static_cast<UInt64>(base_it->first + action_it->extra))
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
@ -492,7 +466,7 @@ protected:
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeGreater)
|
||||
{
|
||||
if (events_it->first - base_it->first > action_it->extra)
|
||||
if (events_it->first > static_cast<UInt64>(base_it->first + action_it->extra))
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
@ -575,14 +549,14 @@ private:
|
||||
DFAStates dfa_states;
|
||||
};
|
||||
|
||||
|
||||
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern)
|
||||
: AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>(arguments, params, pattern) {}
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>::AggregateFunctionSequenceBase;
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceMatch"; }
|
||||
|
||||
@ -590,27 +564,27 @@ public:
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
const_cast<Data &>(data(place)).sort();
|
||||
const_cast<Data &>(this->data(place)).sort();
|
||||
|
||||
const auto & data_ref = data(place);
|
||||
const auto & data_ref = this->data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.events_list);
|
||||
const auto events_end = std::end(data_ref.events_list);
|
||||
auto events_it = events_begin;
|
||||
|
||||
bool match = pattern_has_time ? backtrackingMatch(events_it, events_end) : dfaMatch(events_it, events_end);
|
||||
bool match = this->pattern_has_time ? this->backtrackingMatch(events_it, events_end) : this->dfaMatch(events_it, events_end);
|
||||
static_cast<ColumnUInt8 &>(to).getData().push_back(match);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern)
|
||||
: AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>(arguments, params, pattern) {}
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>::AggregateFunctionSequenceBase;
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceCount"; }
|
||||
|
||||
@ -618,21 +592,21 @@ public:
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
const_cast<Data &>(data(place)).sort();
|
||||
const_cast<Data &>(this->data(place)).sort();
|
||||
static_cast<ColumnUInt64 &>(to).getData().push_back(count(place));
|
||||
}
|
||||
|
||||
private:
|
||||
UInt64 count(const ConstAggregateDataPtr & place) const
|
||||
{
|
||||
const auto & data_ref = data(place);
|
||||
const auto & data_ref = this->data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.events_list);
|
||||
const auto events_end = std::end(data_ref.events_list);
|
||||
auto events_it = events_begin;
|
||||
|
||||
size_t count = 0;
|
||||
while (events_it != events_end && backtrackingMatch(events_it, events_end))
|
||||
while (events_it != events_end && this->backtrackingMatch(events_it, events_end))
|
||||
++count;
|
||||
|
||||
return count;
|
||||
|
@ -1 +1,2 @@
|
||||
1
|
||||
ABC
|
||||
ABBC
|
||||
|
@ -12,14 +12,24 @@ ENGINE = Memory;
|
||||
|
||||
INSERT INTO sequence SELECT 1, number = 0 ? 'A' : (number < 1000000 ? 'B' : 'C'), number FROM numbers(1000001);
|
||||
|
||||
SELECT userID
|
||||
SELECT 'ABC'
|
||||
FROM sequence
|
||||
GROUP BY userID
|
||||
HAVING sequenceMatch('(?1).*(?2).*(?3)')(toDateTime(EventTime), eventType = 'A', eventType = 'B', eventType = 'C');
|
||||
|
||||
SELECT userID
|
||||
SELECT 'ABA'
|
||||
FROM sequence
|
||||
GROUP BY userID
|
||||
HAVING sequenceMatch('(?1).*(?2).*(?3)')(toDateTime(EventTime), eventType = 'A', eventType = 'B', eventType = 'A');
|
||||
|
||||
SELECT 'ABBC'
|
||||
FROM sequence
|
||||
GROUP BY userID
|
||||
HAVING sequenceMatch('(?1).*(?2).*(?3).*(?4)')(EventTime, eventType = 'A', eventType = 'B', eventType = 'B',eventType = 'C');
|
||||
|
||||
SELECT 'CBA'
|
||||
FROM sequence
|
||||
GROUP BY userID
|
||||
HAVING sequenceMatch('(?1).*(?2).*(?3)')(EventTime, eventType = 'C', eventType = 'B', eventType = 'A');
|
||||
|
||||
DROP TABLE sequence;
|
||||
|
Loading…
Reference in New Issue
Block a user