From 2cc69893f26920aa41c5da497035b46e6ce67360 Mon Sep 17 00:00:00 2001 From: achimbab <07c00h@gmail.com> Date: Thu, 4 Feb 2021 16:15:04 +0900 Subject: [PATCH] Add sequenceFirstNode --- .../AggregateFunctionSequenceNextNode.cpp | 43 ++++-- .../AggregateFunctionSequenceNextNode.h | 135 ++++++++++++++++++ .../01656_sequence_next_node.reference | 48 +++++++ .../0_stateless/01656_sequence_next_node.sql | 10 ++ 4 files changed, 221 insertions(+), 15 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp index 57ba87c922f..af90c80de61 100644 --- a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp +++ b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp @@ -23,15 +23,27 @@ namespace { template -inline AggregateFunctionPtr createAggregateFunctionSequenceNextNodeImpl(const DataTypePtr data_type, const DataTypes & argument_types, bool descending_order) +inline AggregateFunctionPtr createAggregateFunctionSequenceNodeImpl(const DataTypePtr data_type, const DataTypes & argument_types, bool descending_order) { - if (descending_order) - return std::make_shared>(data_type, argument_types); + if (argument_types.size() == 2) + { + // If the number of arguments of sequenceNextNode is 2, the sequenceNextNode acts as sequenceFirstNode. + if (descending_order) + return std::make_shared>(data_type); + else + return std::make_shared>(data_type); + } else - return std::make_shared>(data_type, argument_types); + { + if (descending_order) + return std::make_shared>(data_type, argument_types); + else + return std::make_shared>(data_type, argument_types); + } } -AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string & name, const DataTypes & argument_types, const Array & parameters) +template +AggregateFunctionPtr createAggregateFunctionSequenceNode(const std::string & name, const DataTypes & argument_types, const Array & parameters) { bool descending_order = false; @@ -47,9 +59,9 @@ AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string & throw Exception("Incorrect number of parameters for aggregate function " + name + ", should be 1", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - if (argument_types.size() < 3) - throw Exception("Aggregate function " + name + " requires at least three arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - else if (argument_types.size() > 2 + 31) + if (argument_types.size() < 2) + throw Exception("Aggregate function " + name + " requires at least two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + else if (argument_types.size() > MaxArgs) throw Exception("Aggregate function " + name + " requires at most 34(timestamp, value_column, 31 events) arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); for (const auto i : ext::range(2, argument_types.size())) @@ -73,17 +85,17 @@ AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string & WhichDataType timestamp_type(argument_types[0].get()); if (timestamp_type.idx == TypeIndex::UInt8) - return createAggregateFunctionSequenceNextNodeImpl(data_type, argument_types, descending_order); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, descending_order); if (timestamp_type.idx == TypeIndex::UInt16) - return createAggregateFunctionSequenceNextNodeImpl(data_type, argument_types, descending_order); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, descending_order); if (timestamp_type.idx == TypeIndex::UInt32) - return createAggregateFunctionSequenceNextNodeImpl(data_type, argument_types, descending_order); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, descending_order); if (timestamp_type.idx == TypeIndex::UInt64) - return createAggregateFunctionSequenceNextNodeImpl(data_type, argument_types, descending_order); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, descending_order); if (timestamp_type.isDate()) - return createAggregateFunctionSequenceNextNodeImpl(data_type, argument_types, descending_order); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, descending_order); if (timestamp_type.isDateTime()) - return createAggregateFunctionSequenceNextNodeImpl(data_type, argument_types, descending_order); + return createAggregateFunctionSequenceNodeImpl(data_type, argument_types, descending_order); throw Exception{"Illegal type " + argument_types.front().get()->getName() + " of first argument of aggregate function " + name + ", must be Unsigned Number, Date, DateTime", @@ -95,7 +107,8 @@ AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string & void registerAggregateFunctionSequenceNextNode(AggregateFunctionFactory & factory) { AggregateFunctionProperties properties = { .returns_default_when_only_null = true, .is_order_dependent = false }; - factory.registerFunction("sequenceNextNode", { createAggregateFunctionSequenceNextNode, properties }); + factory.registerFunction("sequenceNextNode", { createAggregateFunctionSequenceNode<2 + 31>, properties }); + factory.registerFunction("sequenceFirstNode", { createAggregateFunctionSequenceNode<2>, properties }); } } diff --git a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h index 888149c77da..ffcc02b805a 100644 --- a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h +++ b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h @@ -309,4 +309,139 @@ public: bool allocatesMemoryInArena() const override { return true; } }; +template +class SequenceFirstNodeImpl final + : public IAggregateFunctionDataHelper, SequenceFirstNodeImpl> +{ + using Data = SequenceNextNodeGeneralData; + static Data & data(AggregateDataPtr place) { return *reinterpret_cast(place); } + static const Data & data(ConstAggregateDataPtr place) { return *reinterpret_cast(place); } + + DataTypePtr & data_type; + +public: + SequenceFirstNodeImpl(const DataTypePtr & data_type_) + : IAggregateFunctionDataHelper, SequenceFirstNodeImpl>( + {data_type_}, {}) + , data_type(this->argument_types[0]) + { + } + + String getName() const override { return "sequenceFirstNode"; } + + DataTypePtr getReturnType() const override { return data_type; } + + AggregateFunctionPtr getOwnNullAdapter( + const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params, + const AggregateFunctionProperties &) const override + { + return std::make_shared>(nested_function, arguments, params); + } + + void insert(Data & a, const Node * v, Arena * arena) const + { + ++a.total_values; + a.value.push_back(v->clone(arena), arena); + } + + void create(AggregateDataPtr place) const override + { + [[maybe_unused]] auto a = new (place) Data; + } + + bool compare(const T lhs_timestamp, const T rhs_timestamp) const + { + return Descending ? lhs_timestamp < rhs_timestamp : lhs_timestamp > rhs_timestamp; + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + bool is_first = true; + auto & value = data(place).value; + const auto timestamp = assert_cast *>(columns[0])->getData()[row_num]; + + if (value.size() != 0) + { + if (compare(value[0]->event_time, timestamp)) + value.pop_back(); + else + is_first = false; + } + + + if (is_first) + { + Node * node = Node::allocate(*columns[1], row_num, arena); + node->event_time = timestamp; + node->events_bitset = 0x80000000; + + data(place).value.push_back(node, arena); + } + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override + { + auto & a = data(place).value; + auto & b = data(rhs).value; + + if (b.empty()) + return; + + if (a.empty()) + { + a.push_back(b[0]->clone(arena), arena); + return; + } + + if (compare(a[0]->event_time, b[0]->event_time)) + { + data(place).value.pop_back(); + a.push_back(b[0]->clone(arena), arena); + } + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + writeBinary(data(place).sorted, buf); + + auto & value = data(place).value; + writeVarUInt(value.size(), buf); + for (auto & node : value) + node->write(buf); + } + + void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override + { + readBinary(data(place).sorted, buf); + + UInt64 size; + readVarUInt(size, buf); + + if (unlikely(size == 0)) + return; + + auto & value = data(place).value; + + value.resize(size, arena); + for (UInt64 i = 0; i < size; ++i) + value[i] = Node::read(buf, arena); + } + + void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override + { + auto & value = data(place).value; + + if (value.size() > 0) + { + ColumnNullable & to_concrete = assert_cast(to); + value[0]->insertInto(to_concrete.getNestedColumn()); + to_concrete.getNullMapData().push_back(0); + } + else + to.insertDefault(); + } + + bool allocatesMemoryInArena() const override { return true; } +}; + } diff --git a/tests/queries/0_stateless/01656_sequence_next_node.reference b/tests/queries/0_stateless/01656_sequence_next_node.reference index 3e8da4bbd48..50755232cb9 100644 --- a/tests/queries/0_stateless/01656_sequence_next_node.reference +++ b/tests/queries/0_stateless/01656_sequence_next_node.reference @@ -132,6 +132,30 @@ (0, C) id = 11 0 (0, C->B) id = 11 0 (0, C->B->A) id = 11 0 +(0) id < 10 1 A +(0) id < 10 2 A +(0) id < 10 3 A +(0) id < 10 4 A +(0) id < 10 5 A +(0) id < 10 6 A +(0) id < 10 1 A +(0) id < 10 2 A +(0) id < 10 3 A +(0) id < 10 4 A +(0) id < 10 5 A +(0) id < 10 6 A +(1) id < 10 1 D +(1) id < 10 2 C +(1) id < 10 3 B +(1) id < 10 4 C +(1) id < 10 5 C +(1) id < 10 6 C +(1) id < 10 1 D +(1) id < 10 2 C +(1) id < 10 3 B +(1) id < 10 4 C +(1) id < 10 5 C +(1) id < 10 6 C (0, A) 1 B (0, A) 2 B (0, A) 3 B @@ -266,3 +290,27 @@ (0, C) id = 11 1 (0, C->B) id = 11 1 (0, C->B->A) id = 11 1 +(0) id < 10 1 A +(0) id < 10 2 A +(0) id < 10 3 A +(0) id < 10 4 A +(0) id < 10 5 A +(0) id < 10 6 A +(0) id < 10 1 A +(0) id < 10 2 A +(0) id < 10 3 A +(0) id < 10 4 A +(0) id < 10 5 A +(0) id < 10 6 A +(1) id < 10 1 D +(1) id < 10 2 C +(1) id < 10 3 B +(1) id < 10 4 C +(1) id < 10 5 C +(1) id < 10 6 C +(1) id < 10 1 D +(1) id < 10 2 C +(1) id < 10 3 B +(1) id < 10 4 C +(1) id < 10 5 C +(1) id < 10 6 C diff --git a/tests/queries/0_stateless/01656_sequence_next_node.sql b/tests/queries/0_stateless/01656_sequence_next_node.sql index 31c224fd2a4..9af59d5c8e2 100644 --- a/tests/queries/0_stateless/01656_sequence_next_node.sql +++ b/tests/queries/0_stateless/01656_sequence_next_node.sql @@ -70,6 +70,11 @@ SELECT '(0, C) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action SELECT '(0, C->B) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action, action = 'C', action ='B') AS next_node FROM test_sequenceNextNode_Nullable WHERE id = 11 GROUP BY id HAVING next_node in ('A', 'D')); SELECT '(0, C->B->A) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action, action = 'C', action = 'B', action = 'A') AS next_node FROM test_sequenceNextNode_Nullable WHERE id = 11 GROUP BY id HAVING next_node in ('D')); +SELECT '(0) id < 10', id, sequenceNextNode(0)(dt, action) AS next_node FROM test_sequenceNextNode_Nullable WHERE id < 10 GROUP BY id ORDER BY id; +SELECT '(0) id < 10', id, sequenceFirstNode(0)(dt, action) AS next_node FROM test_sequenceNextNode_Nullable WHERE id < 10 GROUP BY id ORDER BY id; +SELECT '(1) id < 10', id, sequenceNextNode(1)(dt, action) AS next_node FROM test_sequenceNextNode_Nullable WHERE id < 10 GROUP BY id ORDER BY id; +SELECT '(1) id < 10', id, sequenceFirstNode(1)(dt, action) AS next_node FROM test_sequenceNextNode_Nullable WHERE id < 10 GROUP BY id ORDER BY id; + DROP TABLE IF EXISTS test_sequenceNextNode_Nullable; -- The same testcases for a non-null type. @@ -150,4 +155,9 @@ SELECT '(0, C) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action SELECT '(0, C->B) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action, action = 'C', action ='B') AS next_node FROM test_sequenceNextNode WHERE id = 11 GROUP BY id HAVING next_node in ('A', 'D')); SELECT '(0, C->B->A) id = 11', count() FROM (SELECT id, sequenceNextNode(1)(dt, action, action = 'C', action = 'B', action = 'A') AS next_node FROM test_sequenceNextNode WHERE id = 11 GROUP BY id HAVING next_node in ('D')); +SELECT '(0) id < 10', id, sequenceNextNode(0)(dt, action) AS next_node FROM test_sequenceNextNode WHERE id < 10 GROUP BY id ORDER BY id; +SELECT '(0) id < 10', id, sequenceFirstNode(0)(dt, action) AS next_node FROM test_sequenceNextNode WHERE id < 10 GROUP BY id ORDER BY id; +SELECT '(1) id < 10', id, sequenceNextNode(1)(dt, action) AS next_node FROM test_sequenceNextNode WHERE id < 10 GROUP BY id ORDER BY id; +SELECT '(1) id < 10', id, sequenceFirstNode(1)(dt, action) AS next_node FROM test_sequenceNextNode WHERE id < 10 GROUP BY id ORDER BY id; + DROP TABLE IF EXISTS test_sequenceNextNode;