Add sequenceFirstNode

This commit is contained in:
achimbab 2021-02-04 16:15:04 +09:00
parent 19dd09ea8e
commit 2cc69893f2
4 changed files with 221 additions and 15 deletions

View File

@ -23,15 +23,27 @@ namespace
{
template <typename TYPE>
inline AggregateFunctionPtr createAggregateFunctionSequenceNextNodeImpl(const DataTypePtr data_type, const DataTypes & argument_types, bool descending_order)
inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl(const DataTypePtr data_type, const DataTypes & argument_types, bool descending_order)
{
if (argument_types.size() == 2)
{
// If the number of arguments of sequenceNextNode is 2, the sequenceNextNode acts as sequenceFirstNode.
if (descending_order)
return std::make_shared<SequenceFirstNodeImpl<TYPE, NodeString, true>>(data_type);
else
return std::make_shared<SequenceFirstNodeImpl<TYPE, NodeString, false>>(data_type);
}
else
{
if (descending_order)
return std::make_shared<SequenceNextNodeImpl<TYPE, NodeString, true>>(data_type, argument_types);
else
return std::make_shared<SequenceNextNodeImpl<TYPE, NodeString, false>>(data_type, argument_types);
}
}
AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string & name, const DataTypes & argument_types, const Array & parameters)
template <UInt64 MaxArgs>
AggregateFunctionPtr createAggregateFunctionSequenceNode(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
bool descending_order = false;
@ -47,9 +59,9 @@ AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string &
throw Exception("Incorrect number of parameters for aggregate function " + name + ", should be 1",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (argument_types.size() < 3)
throw Exception("Aggregate function " + name + " requires at least three arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
else if (argument_types.size() > 2 + 31)
if (argument_types.size() < 2)
throw Exception("Aggregate function " + name + " requires at least two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
else if (argument_types.size() > MaxArgs)
throw Exception("Aggregate function " + name + " requires at most 34(timestamp, value_column, 31 events) arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (const auto i : ext::range(2, argument_types.size()))
@ -73,17 +85,17 @@ AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string &
WhichDataType timestamp_type(argument_types[0].get());
if (timestamp_type.idx == TypeIndex::UInt8)
return createAggregateFunctionSequenceNextNodeImpl<UInt8>(data_type, argument_types, descending_order);
return createAggregateFunctionSequenceNodeImpl<UInt8>(data_type, argument_types, descending_order);
if (timestamp_type.idx == TypeIndex::UInt16)
return createAggregateFunctionSequenceNextNodeImpl<UInt16>(data_type, argument_types, descending_order);
return createAggregateFunctionSequenceNodeImpl<UInt16>(data_type, argument_types, descending_order);
if (timestamp_type.idx == TypeIndex::UInt32)
return createAggregateFunctionSequenceNextNodeImpl<UInt32>(data_type, argument_types, descending_order);
return createAggregateFunctionSequenceNodeImpl<UInt32>(data_type, argument_types, descending_order);
if (timestamp_type.idx == TypeIndex::UInt64)
return createAggregateFunctionSequenceNextNodeImpl<UInt64>(data_type, argument_types, descending_order);
return createAggregateFunctionSequenceNodeImpl<UInt64>(data_type, argument_types, descending_order);
if (timestamp_type.isDate())
return createAggregateFunctionSequenceNextNodeImpl<DataTypeDate::FieldType>(data_type, argument_types, descending_order);
return createAggregateFunctionSequenceNodeImpl<DataTypeDate::FieldType>(data_type, argument_types, descending_order);
if (timestamp_type.isDateTime())
return createAggregateFunctionSequenceNextNodeImpl<DataTypeDateTime::FieldType>(data_type, argument_types, descending_order);
return createAggregateFunctionSequenceNodeImpl<DataTypeDateTime::FieldType>(data_type, argument_types, descending_order);
throw Exception{"Illegal type " + argument_types.front().get()->getName()
+ " of first argument of aggregate function " + name + ", must be Unsigned Number, Date, DateTime",
@ -95,7 +107,8 @@ AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string &
void registerAggregateFunctionSequenceNextNode(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties properties = { .returns_default_when_only_null = true, .is_order_dependent = false };
factory.registerFunction("sequenceNextNode", { createAggregateFunctionSequenceNextNode, properties });
factory.registerFunction("sequenceNextNode", { createAggregateFunctionSequenceNode<2 + 31>, properties });
factory.registerFunction("sequenceFirstNode", { createAggregateFunctionSequenceNode<2>, properties });
}
}

View File

@ -309,4 +309,139 @@ public:
bool allocatesMemoryInArena() const override { return true; }
};
template <typename T, typename Node, bool Descending>
class SequenceFirstNodeImpl final
: public IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<T, Node, Descending>, SequenceFirstNodeImpl<T, Node, Descending>>
{
using Data = SequenceNextNodeGeneralData<T, Node, Descending>;
static Data & data(AggregateDataPtr place) { return *reinterpret_cast<Data *>(place); }
static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast<const Data *>(place); }
DataTypePtr & data_type;
public:
SequenceFirstNodeImpl(const DataTypePtr & data_type_)
: IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<T, Node, Descending>, SequenceFirstNodeImpl<T, Node, Descending>>(
{data_type_}, {})
, data_type(this->argument_types[0])
{
}
String getName() const override { return "sequenceFirstNode"; }
DataTypePtr getReturnType() const override { return data_type; }
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params,
const AggregateFunctionProperties &) const override
{
return std::make_shared<AggregateFunctionNullVariadic<false, false, true, true>>(nested_function, arguments, params);
}
void insert(Data & a, const Node * v, Arena * arena) const
{
++a.total_values;
a.value.push_back(v->clone(arena), arena);
}
void create(AggregateDataPtr place) const override
{
[[maybe_unused]] auto a = new (place) Data;
}
bool compare(const T lhs_timestamp, const T rhs_timestamp) const
{
return Descending ? lhs_timestamp < rhs_timestamp : lhs_timestamp > rhs_timestamp;
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
bool is_first = true;
auto & value = data(place).value;
const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
if (value.size() != 0)
{
if (compare(value[0]->event_time, timestamp))
value.pop_back();
else
is_first = false;
}
if (is_first)
{
Node * node = Node::allocate(*columns[1], row_num, arena);
node->event_time = timestamp;
node->events_bitset = 0x80000000;
data(place).value.push_back(node, arena);
}
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
auto & a = data(place).value;
auto & b = data(rhs).value;
if (b.empty())
return;
if (a.empty())
{
a.push_back(b[0]->clone(arena), arena);
return;
}
if (compare(a[0]->event_time, b[0]->event_time))
{
data(place).value.pop_back();
a.push_back(b[0]->clone(arena), arena);
}
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
writeBinary(data(place).sorted, buf);
auto & value = data(place).value;
writeVarUInt(value.size(), buf);
for (auto & node : value)
node->write(buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{
readBinary(data(place).sorted, buf);
UInt64 size;
readVarUInt(size, buf);
if (unlikely(size == 0))
return;
auto & value = data(place).value;
value.resize(size, arena);
for (UInt64 i = 0; i < size; ++i)
value[i] = Node::read(buf, arena);
}
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override
{
auto & value = data(place).value;
if (value.size() > 0)
{
ColumnNullable & to_concrete = assert_cast<ColumnNullable &>(to);
value[0]->insertInto(to_concrete.getNestedColumn());
to_concrete.getNullMapData().push_back(0);
}
else
to.insertDefault();
}
bool allocatesMemoryInArena() const override { return true; }
};
}

View File

@ -132,6 +132,30 @@
(0, C) id = 11 0
(0, C->B) id = 11 0
(0, C->B->A) id = 11 0
(0) id < 10 1 A
(0) id < 10 2 A
(0) id < 10 3 A
(0) id < 10 4 A
(0) id < 10 5 A
(0) id < 10 6 A
(0) id < 10 1 A
(0) id < 10 2 A
(0) id < 10 3 A
(0) id < 10 4 A
(0) id < 10 5 A
(0) id < 10 6 A
(1) id < 10 1 D
(1) id < 10 2 C
(1) id < 10 3 B
(1) id < 10 4 C
(1) id < 10 5 C
(1) id < 10 6 C
(1) id < 10 1 D
(1) id < 10 2 C
(1) id < 10 3 B
(1) id < 10 4 C
(1) id < 10 5 C
(1) id < 10 6 C
(0, A) 1 B
(0, A) 2 B
(0, A) 3 B
@ -266,3 +290,27 @@
(0, C) id = 11 1
(0, C->B) id = 11 1
(0, C->B->A) id = 11 1
(0) id < 10 1 A
(0) id < 10 2 A
(0) id < 10 3 A
(0) id < 10 4 A
(0) id < 10 5 A
(0) id < 10 6 A
(0) id < 10 1 A
(0) id < 10 2 A
(0) id < 10 3 A
(0) id < 10 4 A
(0) id < 10 5 A
(0) id < 10 6 A
(1) id < 10 1 D
(1) id < 10 2 C
(1) id < 10 3 B
(1) id < 10 4 C
(1) id < 10 5 C
(1) id < 10 6 C
(1) id < 10 1 D
(1) id < 10 2 C
(1) id < 10 3 B
(1) id < 10 4 C
(1) id < 10 5 C
(1) id < 10 6 C

View File

@ -70,6 +70,11 @@ SELECT '(0, C) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action
SELECT '(0, C->B) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action, action = 'C', action ='B') AS next_node FROM test_sequenceNextNode_Nullable WHERE id = 11 GROUP BY id HAVING next_node in ('A', 'D'));
SELECT '(0, C->B->A) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action, action = 'C', action = 'B', action = 'A') AS next_node FROM test_sequenceNextNode_Nullable WHERE id = 11 GROUP BY id HAVING next_node in ('D'));
SELECT '(0) id < 10', id, sequenceNextNode(0)(dt, action) AS next_node FROM test_sequenceNextNode_Nullable WHERE id < 10 GROUP BY id ORDER BY id;
SELECT '(0) id < 10', id, sequenceFirstNode(0)(dt, action) AS next_node FROM test_sequenceNextNode_Nullable WHERE id < 10 GROUP BY id ORDER BY id;
SELECT '(1) id < 10', id, sequenceNextNode(1)(dt, action) AS next_node FROM test_sequenceNextNode_Nullable WHERE id < 10 GROUP BY id ORDER BY id;
SELECT '(1) id < 10', id, sequenceFirstNode(1)(dt, action) AS next_node FROM test_sequenceNextNode_Nullable WHERE id < 10 GROUP BY id ORDER BY id;
DROP TABLE IF EXISTS test_sequenceNextNode_Nullable;
-- The same testcases for a non-null type.
@ -150,4 +155,9 @@ SELECT '(0, C) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action
SELECT '(0, C->B) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action, action = 'C', action ='B') AS next_node FROM test_sequenceNextNode WHERE id = 11 GROUP BY id HAVING next_node in ('A', 'D'));
SELECT '(0, C->B->A) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action, action = 'C', action = 'B', action = 'A') AS next_node FROM test_sequenceNextNode WHERE id = 11 GROUP BY id HAVING next_node in ('D'));
SELECT '(0) id < 10', id, sequenceNextNode(0)(dt, action) AS next_node FROM test_sequenceNextNode WHERE id < 10 GROUP BY id ORDER BY id;
SELECT '(0) id < 10', id, sequenceFirstNode(0)(dt, action) AS next_node FROM test_sequenceNextNode WHERE id < 10 GROUP BY id ORDER BY id;
SELECT '(1) id < 10', id, sequenceNextNode(1)(dt, action) AS next_node FROM test_sequenceNextNode WHERE id < 10 GROUP BY id ORDER BY id;
SELECT '(1) id < 10', id, sequenceFirstNode(1)(dt, action) AS next_node FROM test_sequenceNextNode WHERE id < 10 GROUP BY id ORDER BY id;
DROP TABLE IF EXISTS test_sequenceNextNode;