Refactoring

This commit is contained in:
achimbab 2021-01-28 12:00:37 +09:00
parent fce1ca255d
commit e93caefd62
4 changed files with 71 additions and 47 deletions

View File

@ -157,15 +157,8 @@ public:
ColumnNullable & to_concrete = assert_cast<ColumnNullable &>(to);
if (getFlag(place))
{
if (unlikely(nested_function->doesInsertResultNeedNullableColumn()))
{
nested_function->insertResultInto(nestedPlace(place), to_concrete, arena);
}
else
{
nested_function->insertResultInto(nestedPlace(place), to_concrete.getNestedColumn(), arena);
to_concrete.getNullMapData().push_back(0);
}
nested_function->insertResultInto(nestedPlace(place), to_concrete.getNestedColumn(), arena);
to_concrete.getNullMapData().push_back(0);
}
else
{
@ -235,7 +228,7 @@ public:
};
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped, bool insertion_requires_nullable_column = false>
class AggregateFunctionNullVariadic final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>
@ -283,6 +276,35 @@ public:
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
}
/// Inserts the aggregation result for `place` into column `to`.
/// When result_is_nullable, `to` must be a ColumnNullable; otherwise the result
/// is forwarded to the nested function unchanged.
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const override
{
if constexpr (result_is_nullable)
{
ColumnNullable & to_concrete = assert_cast<ColumnNullable &>(to);
/// The flag records whether any non-skipped value was aggregated into this place.
if (this->getFlag(place))
{
if constexpr (insertion_requires_nullable_column)
{
/// The nested function may itself produce NULL (e.g. sequenceNextNode when no
/// match is found), so pass the whole nullable column and let it manage the null map.
this->nested_function->insertResultInto(this->nestedPlace(place), to_concrete, arena);
}
else
{
/// The nested function writes only the value; mark the row as non-NULL here.
this->nested_function->insertResultInto(this->nestedPlace(place), to_concrete.getNestedColumn(), arena);
to_concrete.getNullMapData().push_back(0);
}
}
else
{
/// Nothing was aggregated — the result is NULL.
to_concrete.insertDefault();
}
}
else
{
this->nested_function->insertResultInto(this->nestedPlace(place), to, arena);
}
}
private:
enum { MAX_ARGS = 8 };
size_t number_of_arguments = 0;

View File

@ -49,8 +49,8 @@ AggregateFunctionPtr createAggregateFunctionSequenceNextNode(const std::string &
if (argument_types.size() < 3)
throw Exception("Aggregate function " + name + " requires at least three arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
else if (argument_types.size() > 2 + 64)
throw Exception("Aggregate function " + name + " requires at most 66(timestamp, value_column, 64 events) arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
else if (argument_types.size() > 2 + 32)
throw Exception("Aggregate function " + name + " requires at most 34(timestamp, value_column, 32 events) arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (const auto i : ext::range(2, argument_types.size()))
{

View File

@ -11,7 +11,6 @@
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeDateTime.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnNullable.h>
@ -20,6 +19,7 @@
#include <Common/assert_cast.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionNull.h>
#include <type_traits>
@ -33,28 +33,24 @@ struct NodeBase
UInt64 size; // size of payload
DataTypeDateTime::FieldType event_time;
UInt64 events_bitmap;
UInt32 events_bitset; // UInt32 allows convenient comparisons between bitsets (< operator on bitsets).
/// Returns pointer to actual payload
char * data() { return reinterpret_cast<char *>(this) + sizeof(Node); }
const char * data() const { return reinterpret_cast<const char *>(this) + sizeof(Node); }
/// Clones existing node (does not modify next field)
Node * clone(Arena * arena) const
{
return reinterpret_cast<Node *>(
const_cast<char *>(arena->alignedInsert(reinterpret_cast<const char *>(this), sizeof(Node) + size, alignof(Node))));
}
/// Write node to buffer
void write(WriteBuffer & buf) const
{
writeVarUInt(size, buf);
buf.write(data(), size);
}
/// Reads and allocates node from ReadBuffer's data (doesn't set next)
static Node * read(ReadBuffer & buf, Arena * arena)
{
UInt64 size;
@ -71,7 +67,6 @@ struct NodeString : public NodeBase<NodeString>
{
using Node = NodeString;
/// Create node from string
static Node * allocate(const IColumn & column, size_t row_num, Arena * arena)
{
StringRef string = assert_cast<const ColumnString &>(column).getDataAt(row_num);
@ -92,7 +87,6 @@ struct NodeString : public NodeBase<NodeString>
template <typename T, typename Node, bool Descending>
struct SequenceNextNodeGeneralData
{
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
using Allocator = MixedAlignedArenaAllocator<alignof(Node *), 4096>;
using Array = PODArray<Node *, 32, Allocator>;
@ -103,12 +97,12 @@ struct SequenceNextNodeGeneralData
{
bool operator()(const Node * lhs, const Node * rhs) const
{
if (Descending)
if constexpr (Descending)
return lhs->event_time == rhs->event_time ?
lhs->events_bitmap < rhs->events_bitmap: lhs->event_time > rhs->event_time;
lhs->events_bitset < rhs->events_bitset: lhs->event_time > rhs->event_time;
else
return lhs->event_time == rhs->event_time ?
lhs->events_bitmap < rhs->events_bitmap : lhs->event_time < rhs->event_time;
lhs->events_bitset < rhs->events_bitset : lhs->event_time < rhs->event_time;
}
};
@ -122,7 +116,6 @@ struct SequenceNextNodeGeneralData
}
};
/// Implementation of groupArray for String or any ComplexObject via Array
template <typename T, typename Node, bool Descending>
class SequenceNextNodeImpl final
: public IAggregateFunctionDataHelper<SequenceNextNodeGeneralData<T, Node, Descending>, SequenceNextNodeImpl<T, Node, Descending>>
@ -149,6 +142,18 @@ public:
DataTypePtr getReturnType() const override { return data_type; }
/// Wraps this function in a null-handling adapter for Nullable arguments.
AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params,
const AggregateFunctionProperties & /*properties*/) const override
{
// Uses the adapter with insertion_requires_nullable_column = true, because this
// function can return NULL even when values were aggregated for the key.
// Example: aggregated events are [A -> B -> C] and the events to find are [C -> D];
// [C -> D] does not match 'A -> B -> C', so the result for that key is NULL.
return std::make_shared<AggregateFunctionNullVariadic<false, false, false, true>>(nested_function, arguments, params);
}
void insert(Data & a, const Node * v, Arena * arena) const
{
++a.total_values;
@ -166,44 +171,42 @@ public:
const auto timestamp = assert_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
UInt64 events_bitmap = 0;
UInt32 events_bitset = 0;
for (UInt8 i = 0; i < events_size; ++i)
if (assert_cast<const ColumnVector<UInt8> *>(columns[2 + i])->getData()[row_num])
events_bitmap += (1 << i);
events_bitset += (1 << i);
node->event_time = timestamp;
node->events_bitmap = events_bitmap;
node->events_bitset = events_bitset;
data(place).value.push_back(node, arena);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if (data(rhs).value.empty()) /// rhs state is empty
if (data(rhs).value.empty())
return;
UInt64 new_elems;
if (data(place).value.size() >= max_elems)
return;
new_elems = std::min(data(rhs).value.size(), static_cast<size_t>(max_elems) - data(place).value.size());
auto & a = data(place).value;
const auto size = a.size();
auto & b = data(rhs).value;
const auto a_size = a.size();
const UInt64 new_elems = std::min(data(rhs).value.size(), static_cast<size_t>(max_elems) - data(place).value.size());
for (UInt64 i = 0; i < new_elems; ++i)
a.push_back(b[i]->clone(arena), arena);
/// either sort whole container or do so partially merging ranges afterwards
using Comparator = typename SequenceNextNodeGeneralData<T, Node, Descending>::Comparator;
/// either sort whole container or do so partially merging ranges afterwards
if (!data(place).sorted && !data(rhs).sorted)
std::stable_sort(std::begin(a), std::end(a), Comparator{});
else
{
const auto begin = std::begin(a);
const auto middle = std::next(begin, size);
const auto middle = std::next(begin, a_size);
const auto end = std::end(a);
if (!data(place).sorted)
@ -242,34 +245,36 @@ public:
value[i] = Node::read(buf, arena);
}
inline UInt64 getSkipCount(const Data & data, const UInt64 i, const UInt64 j) const
inline UInt32 calculateJump(const Data & data, const UInt32 i, const UInt32 j) const
{
UInt64 k = 0;
UInt32 k = 0;
for (; k < events_size - j; ++k)
if (data.value[i - j]->events_bitmap & (1 << (events_size - 1 - j - k)))
if (data.value[i - j]->events_bitset & (1 << (events_size - 1 - j - k)))
return k;
return k;
}
UInt64 getNextNodeIndex(Data & data) const
// This method returns the index of the next node that matched the events.
// The search skips ahead in the manner of the Boyer-Moore algorithm.
UInt32 getNextNodeIndex(Data & data) const
{
if (data.value.size() <= events_size)
return 0;
data.sort();
UInt64 i = events_size - 1;
UInt32 i = events_size - 1;
while (i < data.value.size())
{
UInt64 j = 0;
UInt32 j = 0;
for (; j < events_size; ++j)
if (!(data.value[i - j]->events_bitmap & (1 << (events_size - 1 - j))))
if (!(data.value[i - j]->events_bitset & (1 << (events_size - 1 - j))))
break;
if (j == events_size)
return i + 1;
i += getSkipCount(data, i, j);
i += calculateJump(data, i, j);
}
return 0;
@ -279,7 +284,7 @@ public:
{
auto & value = data(place).value;
UInt64 event_idx = getNextNodeIndex(this->data(place));
UInt32 event_idx = getNextNodeIndex(this->data(place));
if (event_idx != 0 && event_idx < value.size())
{
ColumnNullable & to_concrete = assert_cast<ColumnNullable &>(to);
@ -290,8 +295,6 @@ public:
to.insertDefault();
}
bool doesInsertResultNeedNullableColumn() const override { return true; }
bool allocatesMemoryInArena() const override { return true; }
};

View File

@ -112,7 +112,6 @@ public:
/// in `runningAccumulate`, or when calculating an aggregate function as a
/// window function.
virtual void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const = 0;
virtual bool doesInsertResultNeedNullableColumn() const { return false; }
/// Used for machine learning methods. Predict result from trained model.
/// Will insert result into `to` column for rows in range [offset, offset + limit).