diff --git a/src/AggregateFunctions/AggregateFunctionNull.h b/src/AggregateFunctions/AggregateFunctionNull.h index 3bfcacf7d7b..4dc3c580fd7 100644 --- a/src/AggregateFunctions/AggregateFunctionNull.h +++ b/src/AggregateFunctions/AggregateFunctionNull.h @@ -157,15 +157,8 @@ public: ColumnNullable & to_concrete = assert_cast(to); if (getFlag(place)) { - if (unlikely(nested_function->doesInsertResultNeedNullableColumn())) - { - nested_function->insertResultInto(nestedPlace(place), to_concrete, arena); - } - else - { - nested_function->insertResultInto(nestedPlace(place), to_concrete.getNestedColumn(), arena); - to_concrete.getNullMapData().push_back(0); - } + nested_function->insertResultInto(nestedPlace(place), to_concrete.getNestedColumn(), arena); + to_concrete.getNullMapData().push_back(0); } else { @@ -235,7 +228,7 @@ public: }; -template +template class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase> @@ -283,6 +276,35 @@ public: this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); } + void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const override + { + if constexpr (result_is_nullable) + { + ColumnNullable & to_concrete = assert_cast(to); + if (this->getFlag(place)) + { + if constexpr (insertion_requires_nullable_column) + { + this->nested_function->insertResultInto(this->nestedPlace(place), to_concrete, arena); + } + else + { + this->nested_function->insertResultInto(this->nestedPlace(place), to_concrete.getNestedColumn(), arena); + to_concrete.getNullMapData().push_back(0); + } + } + else + { + to_concrete.insertDefault(); + } + } + else + { + this->nested_function->insertResultInto(this->nestedPlace(place), to, arena); + } + } + + private: enum { MAX_ARGS = 8 }; size_t number_of_arguments = 0; diff --git a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp index 66f24ec8cbf..b185859e00e 100644 --- 
a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp +++ b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.cpp @@ -49,8 +49,8 @@ AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string & if (argument_types.size() < 3) throw Exception("Aggregate function " + name + " requires at least three arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - else if (argument_types.size() > 2 + 64) - throw Exception("Aggregate function " + name + " requires at most 66(timestamp, value_column, 64 events) arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + else if (argument_types.size() > 2 + 32) + throw Exception("Aggregate function " + name + " requires at most 34(timestamp, value_column, 32 events) arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); for (const auto i : ext::range(2, argument_types.size())) { diff --git a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h index a455e16e267..9bdd54e8b4b 100644 --- a/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h +++ b/src/AggregateFunctions/AggregateFunctionSequenceNextNode.h @@ -11,7 +11,6 @@ #include #include -#include #include #include #include @@ -20,6 +19,7 @@ #include #include +#include #include @@ -33,28 +33,24 @@ struct NodeBase UInt64 size; // size of payload DataTypeDateTime::FieldType event_time; - UInt64 events_bitmap; + UInt32 events_bitset; // UInt32 for convenient comparisons between bitsets (< operator on bitsets). 
- /// Returns pointer to actual payload char * data() { return reinterpret_cast(this) + sizeof(Node); } const char * data() const { return reinterpret_cast(this) + sizeof(Node); } - /// Clones existing node (does not modify next field) Node * clone(Arena * arena) const { return reinterpret_cast( const_cast(arena->alignedInsert(reinterpret_cast(this), sizeof(Node) + size, alignof(Node)))); } - /// Write node to buffer void write(WriteBuffer & buf) const { writeVarUInt(size, buf); buf.write(data(), size); } - /// Reads and allocates node from ReadBuffer's data (doesn't set next) static Node * read(ReadBuffer & buf, Arena * arena) { UInt64 size; @@ -71,7 +67,6 @@ struct NodeString : public NodeBase { using Node = NodeString; - /// Create node from string static Node * allocate(const IColumn & column, size_t row_num, Arena * arena) { StringRef string = assert_cast(column).getDataAt(row_num); @@ -92,7 +87,6 @@ struct NodeString : public NodeBase template struct SequenceNextNodeGeneralData { - // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena using Allocator = MixedAlignedArenaAllocator; using Array = PODArray; @@ -103,12 +97,12 @@ struct SequenceNextNodeGeneralData { bool operator()(const Node * lhs, const Node * rhs) const { - if (Descending) + if constexpr (Descending) return lhs->event_time == rhs->event_time ? - lhs->events_bitmap < rhs->events_bitmap: lhs->event_time > rhs->event_time; + lhs->events_bitset < rhs->events_bitset: lhs->event_time > rhs->event_time; else return lhs->event_time == rhs->event_time ? 
- lhs->events_bitmap < rhs->events_bitmap : lhs->event_time < rhs->event_time; + lhs->events_bitset < rhs->events_bitset : lhs->event_time < rhs->event_time; } }; @@ -122,7 +116,6 @@ struct SequenceNextNodeGeneralData } }; -/// Implementation of groupArray for String or any ComplexObject via Array template class SequenceNextNodeImpl final : public IAggregateFunctionDataHelper, SequenceNextNodeImpl> @@ -149,6 +142,18 @@ public: DataTypePtr getReturnType() const override { return data_type; } + AggregateFunctionPtr getOwnNullAdapter( + const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params, + const AggregateFunctionProperties & /*properties*/) const override + { + // This aggregate function turns insertion_requires_nullable_column on. + // Even when some values are mapped to an aggregating key, it can still return NULL, as in the case below. + // aggregated events: [A -> B -> C] + // events to find: [C -> D] + // [C -> D] does not match 'A -> B -> C', so the function returns NULL. 
+ return std::make_shared>(nested_function, arguments, params); + } + void insert(Data & a, const Node * v, Arena * arena) const { ++a.total_values; @@ -166,44 +171,42 @@ public: const auto timestamp = assert_cast *>(columns[0])->getData()[row_num]; - UInt64 events_bitmap = 0; + UInt32 events_bitset = 0; for (UInt8 i = 0; i < events_size; ++i) if (assert_cast *>(columns[2 + i])->getData()[row_num]) - events_bitmap += (1 << i); + events_bitset += (1 << i); node->event_time = timestamp; - node->events_bitmap = events_bitmap; + node->events_bitset = events_bitset; data(place).value.push_back(node, arena); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override { - if (data(rhs).value.empty()) /// rhs state is empty + if (data(rhs).value.empty()) return; - UInt64 new_elems; if (data(place).value.size() >= max_elems) return; - new_elems = std::min(data(rhs).value.size(), static_cast(max_elems) - data(place).value.size()); - auto & a = data(place).value; - const auto size = a.size(); - auto & b = data(rhs).value; + const auto a_size = a.size(); + + const UInt64 new_elems = std::min(data(rhs).value.size(), static_cast(max_elems) - data(place).value.size()); for (UInt64 i = 0; i < new_elems; ++i) a.push_back(b[i]->clone(arena), arena); + /// either sort whole container or do so partially merging ranges afterwards using Comparator = typename SequenceNextNodeGeneralData::Comparator; - /// either sort whole container or do so partially merging ranges afterwards if (!data(place).sorted && !data(rhs).sorted) std::stable_sort(std::begin(a), std::end(a), Comparator{}); else { const auto begin = std::begin(a); - const auto middle = std::next(begin, size); + const auto middle = std::next(begin, a_size); const auto end = std::end(a); if (!data(place).sorted) @@ -242,34 +245,36 @@ public: value[i] = Node::read(buf, arena); } - inline UInt64 getSkipCount(const Data & data, const UInt64 i, const UInt64 j) const + inline UInt32 calculateJump(const 
Data & data, const UInt32 i, const UInt32 j) const { - UInt64 k = 0; + UInt32 k = 0; for (; k < events_size - j; ++k) - if (data.value[i - j]->events_bitmap & (1 << (events_size - 1 - j - k))) + if (data.value[i - j]->events_bitset & (1 << (events_size - 1 - j - k))) return k; return k; } - UInt64 getNextNodeIndex(Data & data) const + // This method returns the index of the node that follows the matched events. + // The search is modeled on the Boyer-Moore algorithm. + UInt32 getNextNodeIndex(Data & data) const { if (data.value.size() <= events_size) return 0; data.sort(); - UInt64 i = events_size - 1; + UInt32 i = events_size - 1; while (i < data.value.size()) { - UInt64 j = 0; + UInt32 j = 0; for (; j < events_size; ++j) - if (!(data.value[i - j]->events_bitmap & (1 << (events_size - 1 - j)))) + if (!(data.value[i - j]->events_bitset & (1 << (events_size - 1 - j)))) break; if (j == events_size) return i + 1; - i += getSkipCount(data, i, j); + i += calculateJump(data, i, j); } return 0; @@ -279,7 +284,7 @@ public: { auto & value = data(place).value; - UInt64 event_idx = getNextNodeIndex(this->data(place)); + UInt32 event_idx = getNextNodeIndex(this->data(place)); if (event_idx != 0 && event_idx < value.size()) { ColumnNullable & to_concrete = assert_cast(to); @@ -290,8 +295,6 @@ public: to.insertDefault(); } - bool doesInsertResultNeedNullableColumn() const override { return true; } - bool allocatesMemoryInArena() const override { return true; } }; diff --git a/src/AggregateFunctions/IAggregateFunction.h b/src/AggregateFunctions/IAggregateFunction.h index d9570fa5f8b..a9fe26688d7 100644 --- a/src/AggregateFunctions/IAggregateFunction.h +++ b/src/AggregateFunctions/IAggregateFunction.h @@ -112,7 +112,6 @@ public: /// in `runningAccumulate`, or when calculating an aggregate function as a /// window function. 
virtual void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const = 0; - virtual bool doesInsertResultNeedNullableColumn() const { return false; } /// Used for machine learning methods. Predict result from trained model. /// Will insert result into `to` column for rows in range [offset, offset + limit).