mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
add aggregate function sequenceCount [#METR-17427]
This commit is contained in:
parent
d43a8cce1f
commit
acb10923de
@ -124,7 +124,7 @@ struct AggregateFunctionSequenceMatchData final
|
||||
/// Max number of iterations to match the pattern against a sequence, exception thrown when exceeded
|
||||
constexpr auto sequence_match_max_iterations = 1000000;
|
||||
|
||||
class AggregateFunctionSequenceMatch final : public IAggregateFunctionHelper<AggregateFunctionSequenceMatchData>
|
||||
class AggregateFunctionSequenceMatch : public IAggregateFunctionHelper<AggregateFunctionSequenceMatchData>
|
||||
{
|
||||
public:
|
||||
static bool sufficientArgs(const std::size_t arg_count) { return arg_count >= 3; }
|
||||
@ -218,7 +218,14 @@ public:
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
const_cast<Data &>(data(place)).sort();
|
||||
static_cast<ColumnUInt8 &>(to).getData().push_back(match(place));
|
||||
|
||||
const auto & data_ref = data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.eventsList);
|
||||
const auto events_end = std::end(data_ref.eventsList);
|
||||
auto events_it = events_begin;
|
||||
|
||||
static_cast<ColumnUInt8 &>(to).getData().push_back(match(events_it, events_end));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -233,21 +240,6 @@ private:
|
||||
TimeGreater
|
||||
};
|
||||
|
||||
static std::string to_string(const PatternActionType type)
|
||||
{
|
||||
static const std::map<PatternActionType, std::string> map{
|
||||
{ PatternActionType::SpecificEvent, "SpecificEvent" },
|
||||
{ PatternActionType::AnyEvent, "AnyEvent" },
|
||||
{ PatternActionType::KleeneStar, "KleeneStar" },
|
||||
{ PatternActionType::TimeLessOrEqual, "TimeLessOrEqual" },
|
||||
{ PatternActionType::TimeLess, "TimeLess", },
|
||||
{ PatternActionType::TimeGreaterOrEqual, "TimeGreaterOrEqual" },
|
||||
{ PatternActionType::TimeGreater, "TimeGreater" }
|
||||
};
|
||||
|
||||
return map.find(type)->second;
|
||||
}
|
||||
|
||||
struct PatternAction final
|
||||
{
|
||||
PatternActionType type;
|
||||
@ -353,18 +345,15 @@ private:
|
||||
this->actions = std::move(actions);
|
||||
}
|
||||
|
||||
bool match(const ConstAggregateDataPtr & place) const
|
||||
protected:
|
||||
template <typename T1, typename T2>
|
||||
bool match(T1 & events_it, const T2 events_end) const
|
||||
{
|
||||
const auto action_begin = std::begin(actions);
|
||||
const auto action_end = std::end(actions);
|
||||
auto action_it = action_begin;
|
||||
|
||||
const auto & data_ref = data(place);
|
||||
const auto events_begin = std::begin(data_ref.eventsList);
|
||||
const auto events_end = std::end(data_ref.eventsList);
|
||||
auto events_it = events_begin;
|
||||
|
||||
auto base_it = events_begin;
|
||||
auto base_it = events_it;
|
||||
|
||||
/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
|
||||
using backtrack_info = std::tuple<decltype(action_it), decltype(events_it), decltype(base_it)>;
|
||||
@ -392,11 +381,6 @@ private:
|
||||
std::size_t i = 0;
|
||||
while (action_it != action_end && events_it != events_end)
|
||||
{
|
||||
// std::cout << "start_timestamp " << base_it->first << "; ";
|
||||
// std::cout << "elapsed " << (events_it->first - base_it->first) << "; ";
|
||||
// std::cout << "action " << (action_it - action_begin) << " { " << to_string(action_it->type) << ' ' << action_it->extra << " }; ";
|
||||
// std::cout << "symbol " << (events_it - events_begin) << " { " << events_it->first << ' ' << events_it->second.to_ulong() << " }" << std::endl;
|
||||
|
||||
if (action_it->type == PatternActionType::SpecificEvent)
|
||||
{
|
||||
if (events_it->second.test(action_it->extra))
|
||||
@ -492,9 +476,40 @@ private:
|
||||
return action_it == action_end;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string pattern;
|
||||
std::size_t arg_count;
|
||||
PatternActions actions;
|
||||
};
|
||||
|
||||
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceMatch
|
||||
{
|
||||
public:
|
||||
String getName() const override { return "sequenceCount"; }
|
||||
|
||||
DataTypePtr getReturnType() const override { return new DataTypeUInt64; }
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
const_cast<Data &>(data(place)).sort();
|
||||
static_cast<ColumnUInt64 &>(to).getData().push_back(count(place));
|
||||
}
|
||||
|
||||
private:
|
||||
UInt64 count(const ConstAggregateDataPtr & place) const
|
||||
{
|
||||
const auto & data_ref = data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.eventsList);
|
||||
const auto events_end = std::end(data_ref.eventsList);
|
||||
auto events_it = events_begin;
|
||||
|
||||
std::size_t count = 0;
|
||||
while (events_it != events_end && match(events_it, events_end))
|
||||
++count;
|
||||
|
||||
return count;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -562,6 +562,13 @@ AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const Da
|
||||
|
||||
return new AggregateFunctionSequenceMatch;
|
||||
}
|
||||
else if (name == "sequenceCount")
|
||||
{
|
||||
if (!AggregateFunctionSequenceCount::sufficientArgs(argument_types.size()))
|
||||
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
return new AggregateFunctionSequenceCount;
|
||||
}
|
||||
else if (name == "varSamp")
|
||||
{
|
||||
if (argument_types.size() != 1)
|
||||
@ -743,6 +750,7 @@ const AggregateFunctionFactory::FunctionNames & AggregateFunctionFactory::getFun
|
||||
"quantileDeterministic",
|
||||
"quantilesDeterministic",
|
||||
"sequenceMatch",
|
||||
"sequenceCount",
|
||||
"varSamp",
|
||||
"varPop",
|
||||
"stddevSamp",
|
||||
|
Loading…
Reference in New Issue
Block a user