add aggregate function sequenceCount [#METR-17427]

This commit is contained in:
Andrey Mironov 2015-08-21 18:57:26 +03:00
parent d43a8cce1f
commit acb10923de
2 changed files with 52 additions and 29 deletions

View File

@ -124,7 +124,7 @@ struct AggregateFunctionSequenceMatchData final
/// Max number of iterations to match the pattern against a sequence; an exception is thrown when it is exceeded
constexpr auto sequence_match_max_iterations = 1000000;
class AggregateFunctionSequenceMatch final : public IAggregateFunctionHelper<AggregateFunctionSequenceMatchData>
class AggregateFunctionSequenceMatch : public IAggregateFunctionHelper<AggregateFunctionSequenceMatchData>
{
public:
/// Tells whether `arg_count` arguments are enough for this aggregate function:
/// at least three are required.
static bool sufficientArgs(const std::size_t arg_count)
{
    constexpr std::size_t min_arg_count = 3;
    return arg_count >= min_arg_count;
}
@ -218,7 +218,14 @@ public:
/// Appends the pattern-match result for the aggregation state `place` to the result column `to`,
/// which is accessed as a ColumnUInt8.
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
/// sort() modifies the state, so the constness of the data behind `place` is cast away here.
const_cast<Data &>(data(place)).sort();
/// NOTE(review): this excerpt is a diff rendered without +/- markers; the line below looks like the
/// pre-change call that the iterator-based code after it replaces — confirm that only one push_back
/// remains in the real file.
static_cast<ColumnUInt8 &>(to).getData().push_back(match(place));
/// match() takes the current-event iterator by non-const reference, so a mutable copy of the
/// begin iterator is made and passed together with the end iterator.
const auto & data_ref = data(place);
const auto events_begin = std::begin(data_ref.eventsList);
const auto events_end = std::end(data_ref.eventsList);
auto events_it = events_begin;
static_cast<ColumnUInt8 &>(to).getData().push_back(match(events_it, events_end));
}
private:
@ -233,21 +240,6 @@ private:
TimeGreater
};
static std::string to_string(const PatternActionType type)
{
static const std::map<PatternActionType, std::string> map{
{ PatternActionType::SpecificEvent, "SpecificEvent" },
{ PatternActionType::AnyEvent, "AnyEvent" },
{ PatternActionType::KleeneStar, "KleeneStar" },
{ PatternActionType::TimeLessOrEqual, "TimeLessOrEqual" },
{ PatternActionType::TimeLess, "TimeLess", },
{ PatternActionType::TimeGreaterOrEqual, "TimeGreaterOrEqual" },
{ PatternActionType::TimeGreater, "TimeGreater" }
};
return map.find(type)->second;
}
struct PatternAction final
{
PatternActionType type;
@ -353,18 +345,15 @@ private:
this->actions = std::move(actions);
}
bool match(const ConstAggregateDataPtr & place) const
protected:
template <typename T1, typename T2>
bool match(T1 & events_it, const T2 events_end) const
{
const auto action_begin = std::begin(actions);
const auto action_end = std::end(actions);
auto action_it = action_begin;
const auto & data_ref = data(place);
const auto events_begin = std::begin(data_ref.eventsList);
const auto events_end = std::end(data_ref.eventsList);
auto events_it = events_begin;
auto base_it = events_begin;
auto base_it = events_it;
/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
using backtrack_info = std::tuple<decltype(action_it), decltype(events_it), decltype(base_it)>;
@ -392,11 +381,6 @@ private:
std::size_t i = 0;
while (action_it != action_end && events_it != events_end)
{
// std::cout << "start_timestamp " << base_it->first << "; ";
// std::cout << "elapsed " << (events_it->first - base_it->first) << "; ";
// std::cout << "action " << (action_it - action_begin) << " { " << to_string(action_it->type) << ' ' << action_it->extra << " }; ";
// std::cout << "symbol " << (events_it - events_begin) << " { " << events_it->first << ' ' << events_it->second.to_ulong() << " }" << std::endl;
if (action_it->type == PatternActionType::SpecificEvent)
{
if (events_it->second.test(action_it->extra))
@ -492,9 +476,40 @@ private:
return action_it == action_end;
}
private:
std::string pattern;
std::size_t arg_count;
PatternActions actions;
};
/** Aggregate function `sequenceCount`.
  * Reuses the pattern matcher inherited from AggregateFunctionSequenceMatch, but instead of a
  * UInt8 match flag it returns, as UInt64, how many times match() succeeded while being applied
  * repeatedly to the sorted event list.
  */
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceMatch
{
public:
    String getName() const override { return "sequenceCount"; }

    DataTypePtr getReturnType() const override { return new DataTypeUInt64; }

    /// Sorts the state behind `place` and appends the number of matches to `to`,
    /// which is accessed as a ColumnUInt64.
    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
    {
        /// sort() modifies the state, so the constness of the data is cast away here.
        const_cast<Data &>(data(place)).sort();
        static_cast<ColumnUInt64 &>(to).getData().push_back(count(place));
    }

private:
    /// Counts successful match() calls. Each call receives the iterator by reference and
    /// continues from wherever the previous call left it.
    UInt64 count(const ConstAggregateDataPtr & place) const
    {
        const auto & data_ref = data(place);
        const auto events_end = std::end(data_ref.eventsList);
        auto events_it = std::begin(data_ref.eventsList);

        UInt64 matches = 0;
        while (events_it != events_end)
        {
            const auto match_start = events_it;

            if (!match(events_it, events_end))
                break;

            ++matches;

            /// A successful match that consumed no events would be repeated from the same
            /// position on every following iteration, so stop instead of looping forever.
            if (events_it == match_start)
                break;
        }

        return matches;
    }
};
}

View File

@ -562,6 +562,13 @@ AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const Da
return new AggregateFunctionSequenceMatch;
}
else if (name == "sequenceCount")
{
if (!AggregateFunctionSequenceCount::sufficientArgs(argument_types.size()))
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return new AggregateFunctionSequenceCount;
}
else if (name == "varSamp")
{
if (argument_types.size() != 1)
@ -743,6 +750,7 @@ const AggregateFunctionFactory::FunctionNames & AggregateFunctionFactory::getFun
"quantileDeterministic",
"quantilesDeterministic",
"sequenceMatch",
"sequenceCount",
"varSamp",
"varPop",
"stddevSamp",