window function rank() and friends

This commit is contained in:
Alexander Kuzmenkov 2021-02-11 16:29:30 +03:00
parent 363007b964
commit 525400bc41
6 changed files with 259 additions and 14 deletions

View File

@ -26,6 +26,7 @@ class ReadBuffer;
class WriteBuffer;
class IColumn;
class IDataType;
class IWindowFunction;
using DataTypePtr = std::shared_ptr<const IDataType>;
using DataTypes = std::vector<DataTypePtr>;
@ -215,6 +216,20 @@ public:
const DataTypes & getArgumentTypes() const { return argument_types; }
const Array & getParameters() const { return parameters; }
// Any aggregate function can be calculated over a window, but there are some
// window functions such as rank() that require a different interface, e.g.
// because they don't respect the window frame, or need to be notified when
// a new peer group starts. They pretend to be normal aggregate functions,
// but will fail if you actually try to use them in Aggregator. The
// WindowTransform recognizes these functions and handles them differently.
// We could have a separate factory for window functions, and make all
// aggregate functions implement IWindowFunction interface and so on. This
// would be more logically correct, but more complex. We only have a handful
// of true window functions, so this hack-ish interface suffices.
virtual IWindowFunction * asWindowFunction() { return nullptr; }
virtual const IWindowFunction * asWindowFunction() const
{ return const_cast<IAggregateFunction *>(this)->asWindowFunction(); }
protected:
DataTypes argument_types;
Array parameters;

View File

@ -58,6 +58,8 @@ void registerAggregateFunctionCombinatorOrFill(AggregateFunctionCombinatorFactor
void registerAggregateFunctionCombinatorResample(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorDistinct(AggregateFunctionCombinatorFactory &);
void registerWindowFunctions(AggregateFunctionFactory & factory);
void registerAggregateFunctions()
{
@ -103,6 +105,8 @@ void registerAggregateFunctions()
registerAggregateFunctionMannWhitney(factory);
registerAggregateFunctionWelchTTest(factory);
registerAggregateFunctionStudentTTest(factory);
registerWindowFunctions(factory);
}
{

View File

@ -4,6 +4,9 @@
#include <Common/Arena.h>
#include <DataTypes/DataTypesNumber.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
namespace DB
{
@ -12,6 +15,18 @@ namespace ErrorCodes
extern const int NOT_IMPLEMENTED;
}
// Interface for true window functions. It's not much of an interface, they just
// accept the guts of WindowTransform and do 'something'. Given a small number of
// true window functions, and the fact that the WindowTransform internals are
// pretty much well defined in domain terms (e.g. frame boundaries), this is
// somewhat acceptable.
class IWindowFunction {
public:
virtual ~IWindowFunction() {}
virtual void windowInsertResultInto(IColumn & to, const WindowTransform * transform) = 0;
};
// Compares ORDER BY column values at given rows to find the boundaries of frame:
// [compared] with [reference] +/- offset. Return value is -1/0/+1, like in
// sorting predicates -- -1 means [compared] is less than [reference] +/- offset.
@ -142,9 +157,14 @@ WindowTransform::WindowTransform(const Block & input_header_,
input_header.getPositionByName(argument_name));
}
workspace.aggregate_function_state.reset(aggregate_function->sizeOfData(),
aggregate_function->alignOfData());
aggregate_function->create(workspace.aggregate_function_state.data());
workspace.window_function_impl = aggregate_function->asWindowFunction();
if (!workspace.window_function_impl)
{
workspace.aggregate_function_state.reset(
aggregate_function->sizeOfData(),
aggregate_function->alignOfData());
aggregate_function->create(workspace.aggregate_function_state.data());
}
workspaces.push_back(std::move(workspace));
}
@ -183,8 +203,11 @@ WindowTransform::~WindowTransform()
// Some states may be not created yet if the creation failed.
for (auto & ws : workspaces)
{
ws.window_function.aggregate_function->destroy(
ws.aggregate_function_state.data());
if (!ws.window_function_impl)
{
ws.window_function.aggregate_function->destroy(
ws.aggregate_function_state.data());
}
}
}
@ -756,6 +779,12 @@ void WindowTransform::updateAggregationState()
for (auto & ws : workspaces)
{
if (ws.window_function_impl)
{
// No need to do anything for true window functions.
continue;
}
const auto * a = ws.window_function.aggregate_function.get();
auto * buf = ws.aggregate_function_state.data();
@ -798,10 +827,10 @@ void WindowTransform::updateAggregationState()
// For now, add the values one by one.
auto * columns = ws.argument_columns.data();
// Removing arena.get() from the loop makes it faster somehow...
auto * arena_ = arena.get();
auto * arena_ptr = arena.get();
for (auto row = first_row; row < past_the_end_row; ++row)
{
a->add(buf, columns, row, arena_);
a->add(buf, columns, row, arena_ptr);
}
}
}
@ -819,14 +848,21 @@ void WindowTransform::writeOutCurrentRow()
for (size_t wi = 0; wi < workspaces.size(); ++wi)
{
auto & ws = workspaces[wi];
const auto & f = ws.window_function;
const auto * a = f.aggregate_function.get();
auto * buf = ws.aggregate_function_state.data();
IColumn * result_column = block.output_columns[wi].get();
// FIXME does it also allocate the result on the arena?
// We'll have to pass it out with blocks then...
a->insertResultInto(buf, *result_column, arena.get());
if (ws.window_function_impl)
{
ws.window_function_impl->windowInsertResultInto(*result_column, this);
}
else
{
const auto & f = ws.window_function;
const auto * a = f.aggregate_function.get();
auto * buf = ws.aggregate_function_state.data();
// FIXME does it also allocate the result on the arena?
// We'll have to pass it out with blocks then...
a->insertResultInto(buf, *result_column, arena.get());
}
}
}
@ -893,6 +929,8 @@ void WindowTransform::appendChunk(Chunk & chunk)
if (!arePeers(peer_group_start, current_row))
{
peer_group_start = current_row;
peer_group_start_row_number = current_row_number;
++peer_group_number;
}
// Advance the frame start.
@ -950,6 +988,7 @@ void WindowTransform::appendChunk(Chunk & chunk)
// The peer group start is updated at the beginning of the loop,
// because current_row might now be past-the-end.
advanceRowNumber(current_row);
++current_row_number;
first_not_ready_row = current_row;
frame_ended = false;
frame_started = false;
@ -983,7 +1022,10 @@ void WindowTransform::appendChunk(Chunk & chunk)
prev_frame_start = partition_start;
prev_frame_end = partition_start;
assert(current_row == partition_start);
current_row_number = 1;
peer_group_start = partition_start;
peer_group_start_row_number = 1;
peer_group_number = 1;
// fmt::print(stderr, "reinitialize agg data at start of {}\n",
// new_partition_start);
@ -991,6 +1033,11 @@ void WindowTransform::appendChunk(Chunk & chunk)
// has started.
for (auto & ws : workspaces)
{
if (ws.window_function_impl)
{
continue;
}
const auto & f = ws.window_function;
const auto * a = f.aggregate_function.get();
auto * buf = ws.aggregate_function_state.data();
@ -1008,6 +1055,11 @@ void WindowTransform::appendChunk(Chunk & chunk)
for (auto & ws : workspaces)
{
if (ws.window_function_impl)
{
continue;
}
const auto & f = ws.window_function;
const auto * a = f.aggregate_function.get();
auto * buf = ws.aggregate_function_state.data();
@ -1175,5 +1227,116 @@ void WindowTransform::work()
}
}
// A basic implementation for a true window function. It pretends to be an
// aggregate function, but refuses to work as such.
struct WindowFunction
: public IAggregateFunctionHelper<WindowFunction>
, public IWindowFunction
{
std::string name;
WindowFunction(const std::string & name_, const DataTypes & argument_types_,
const Array & parameters_)
: IAggregateFunctionHelper<WindowFunction>(argument_types_, parameters_)
, name(name_)
{}
IWindowFunction * asWindowFunction() override { return this; }
[[noreturn]] void fail() const
{
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"The function '{}' can only be used as a window function, not as an aggregate function",
getName());
}
String getName() const override { return name; }
void create(AggregateDataPtr __restrict) const override { fail(); }
void destroy(AggregateDataPtr __restrict) const noexcept override {}
bool hasTrivialDestructor() const override { return true; }
size_t sizeOfData() const override { return 0; }
size_t alignOfData() const override { return 1; }
void add(AggregateDataPtr __restrict, const IColumn **, size_t, Arena *) const override { fail(); }
void merge(AggregateDataPtr __restrict, ConstAggregateDataPtr, Arena *) const override { fail(); }
void serialize(ConstAggregateDataPtr __restrict, WriteBuffer &) const override { fail(); }
void deserialize(AggregateDataPtr __restrict, ReadBuffer &, Arena *) const override { fail(); }
void insertResultInto(AggregateDataPtr __restrict, IColumn &, Arena *) const override { fail(); }
};
struct WindowFunctionRank final : public WindowFunction
{
WindowFunctionRank(const std::string & name_,
const DataTypes & argument_types_, const Array & parameters_)
: WindowFunction(name_, argument_types_, parameters_)
{}
DataTypePtr getReturnType() const override
{ return std::make_shared<DataTypeUInt64>(); }
void windowInsertResultInto(IColumn & to, const WindowTransform * transform) override
{
assert_cast<ColumnUInt64 &>(to).getData().push_back(
transform->peer_group_start_row_number);
}
};
struct WindowFunctionDenseRank final : public WindowFunction
{
WindowFunctionDenseRank(const std::string & name_,
const DataTypes & argument_types_, const Array & parameters_)
: WindowFunction(name_, argument_types_, parameters_)
{}
DataTypePtr getReturnType() const override
{ return std::make_shared<DataTypeUInt64>(); }
void windowInsertResultInto(IColumn & to, const WindowTransform * transform) override
{
assert_cast<ColumnUInt64 &>(to).getData().push_back(
transform->peer_group_number);
}
};
struct WindowFunctionRowNumber final : public WindowFunction
{
WindowFunctionRowNumber(const std::string & name_,
const DataTypes & argument_types_, const Array & parameters_)
: WindowFunction(name_, argument_types_, parameters_)
{}
DataTypePtr getReturnType() const override
{ return std::make_shared<DataTypeUInt64>(); }
void windowInsertResultInto(IColumn & to, const WindowTransform * transform) override
{
assert_cast<ColumnUInt64 &>(to).getData().push_back(
transform->current_row_number);
}
};
void registerWindowFunctions(AggregateFunctionFactory & factory)
{
factory.registerFunction("rank", [](const std::string & name,
const DataTypes & argument_types, const Array & parameters)
{
return std::make_shared<WindowFunctionRank>(name, argument_types,
parameters);
});
factory.registerFunction("dense_rank", [](const std::string & name,
const DataTypes & argument_types, const Array & parameters)
{
return std::make_shared<WindowFunctionDenseRank>(name, argument_types,
parameters);
});
factory.registerFunction("row_number", [](const std::string & name,
const DataTypes & argument_types, const Array & parameters)
{
return std::make_shared<WindowFunctionRowNumber>(name, argument_types,
parameters);
});
}
}

View File

@ -22,6 +22,10 @@ struct WindowFunctionWorkspace
WindowFunctionDescription window_function;
AlignedBuffer aggregate_function_state;
std::vector<size_t> argument_column_indices;
// This field is set for pure window functions. When set, we ignore the
// window_function.aggregate_function, and work through this interface
// instead.
IWindowFunction * window_function_impl = nullptr;
// Argument columns. Be careful, this is a per-block cache.
std::vector<const IColumn *> argument_columns;
@ -282,6 +286,11 @@ public:
// frames may be earlier.
RowNumber peer_group_start;
// Row and group numbers in partition for calculating rank() and friends.
uint64_t current_row_number = 1;
uint64_t peer_group_start_row_number = 1;
uint64_t peer_group_number = 1;
// The frame is [frame_start, frame_end) if frame_ended && frame_started,
// and unknown otherwise. Note that when we move to the next row, both the
// frame_start and the frame_end may jump forward by an unknown amount of

View File

@ -920,3 +920,45 @@ FROM numbers(2)
;
1 0
1 1
-- some true window functions -- rank and friends
select number, p, o,
count(*) over w,
rank() over w,
dense_rank() over w,
row_number() over w
from (select number, intDiv(number, 5) p, mod(number, 3) o
from numbers(31) order by o, number) t
window w as (partition by p order by o)
order by p, o, number
settings max_block_size = 2;
0 0 0 2 1 1 1
3 0 0 2 1 1 2
1 0 1 4 3 2 3
4 0 1 4 3 2 4
2 0 2 5 5 3 5
6 1 0 2 1 1 1
9 1 0 2 1 1 2
7 1 1 3 3 2 3
5 1 2 5 4 3 4
8 1 2 5 4 3 5
12 2 0 1 1 1 1
10 2 1 3 2 2 2
13 2 1 3 2 2 3
11 2 2 5 4 3 4
14 2 2 5 4 3 5
15 3 0 2 1 1 2
18 3 0 2 1 1 1
16 3 1 4 3 2 3
19 3 1 4 3 2 4
17 3 2 5 5 3 5
21 4 0 2 1 1 1
24 4 0 2 1 1 2
22 4 1 3 3 2 3
20 4 2 5 4 3 5
23 4 2 5 4 3 4
27 5 0 1 1 1 1
25 5 1 3 2 2 2
28 5 1 3 2 2 3
26 5 2 5 4 3 4
29 5 2 5 4 3 5
30 6 0 1 1 1 1

View File

@ -315,3 +315,15 @@ SELECT
max(number) OVER (ORDER BY number ASC NULLS FIRST)
FROM numbers(2)
;
-- some true window functions -- rank and friends
select number, p, o,
count(*) over w,
rank() over w,
dense_rank() over w,
row_number() over w
from (select number, intDiv(number, 5) p, mod(number, 3) o
from numbers(31) order by o, number) t
window w as (partition by p order by o)
order by p, o, number
settings max_block_size = 2;