lag/lead stubs + cleanup

This commit is contained in:
Alexander Kuzmenkov 2021-02-11 18:07:42 +03:00
parent 525400bc41
commit ecbcf47f28
5 changed files with 122 additions and 27 deletions

View File

@ -540,7 +540,10 @@ void ExpressionAnalyzer::makeWindowDescriptions(ActionsDAGPtr actions)
!context.getSettingsRef().allow_experimental_window_functions)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED,
"Window functions are not implemented (while processing '{}')",
"The support for window functions is experimental and will change"
" in backwards-incompatible ways in the future releases. Set"
" allow_experimental_window_functions = 1 to enable it."
" While processing '{}'",
syntax->window_function_asts[0]->formatForErrorMessage());
}

View File

@ -1,17 +1,18 @@
#include <Processors/Transforms/WindowTransform.h>
#include <Interpreters/ExpressionActions.h>
#include <Common/Arena.h>
#include <DataTypes/DataTypesNumber.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Common/Arena.h>
#include <DataTypes/DataTypesNumber.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/convertFieldToType.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int NOT_IMPLEMENTED;
}
@ -24,6 +25,7 @@ class IWindowFunction {
public:
virtual ~IWindowFunction() {}
// Must insert the result for current_row.
virtual void windowInsertResultInto(IColumn & to, const WindowTransform * transform) = 0;
};
@ -140,18 +142,15 @@ WindowTransform::WindowTransform(const Block & input_header_,
for (const auto & f : functions)
{
WindowFunctionWorkspace workspace;
workspace.window_function = f;
const auto & aggregate_function
= workspace.window_function.aggregate_function;
workspace.aggregate_function = f.aggregate_function;
const auto & aggregate_function = workspace.aggregate_function;
if (!arena && aggregate_function->allocatesMemoryInArena())
{
arena = std::make_unique<Arena>();
}
workspace.argument_column_indices.reserve(
workspace.window_function.argument_names.size());
for (const auto & argument_name : workspace.window_function.argument_names)
workspace.argument_column_indices.reserve(f.argument_names.size());
for (const auto & argument_name : f.argument_names)
{
workspace.argument_column_indices.push_back(
input_header.getPositionByName(argument_name));
@ -205,7 +204,7 @@ WindowTransform::~WindowTransform()
{
if (!ws.window_function_impl)
{
ws.window_function.aggregate_function->destroy(
ws.aggregate_function->destroy(
ws.aggregate_function_state.data());
}
}
@ -785,7 +784,7 @@ void WindowTransform::updateAggregationState()
continue;
}
const auto * a = ws.window_function.aggregate_function.get();
const auto * a = ws.aggregate_function.get();
auto * buf = ws.aggregate_function_state.data();
if (reset_aggregation)
@ -856,8 +855,7 @@ void WindowTransform::writeOutCurrentRow()
}
else
{
const auto & f = ws.window_function;
const auto * a = f.aggregate_function.get();
const auto * a = ws.aggregate_function.get();
auto * buf = ws.aggregate_function_state.data();
// FIXME does it also allocate the result on the arena?
// We'll have to pass it out with blocks then...
@ -891,8 +889,8 @@ void WindowTransform::appendChunk(Chunk & chunk)
->convertToFullColumnIfConst();
}
block.output_columns.push_back(ws.window_function.aggregate_function
->getReturnType()->createColumn());
block.output_columns.push_back(ws.aggregate_function->getReturnType()
->createColumn());
}
// Even in case of `count() over ()` we should have a dummy input column.
@ -1038,8 +1036,7 @@ void WindowTransform::appendChunk(Chunk & chunk)
continue;
}
const auto & f = ws.window_function;
const auto * a = f.aggregate_function.get();
const auto * a = ws.aggregate_function.get();
auto * buf = ws.aggregate_function_state.data();
a->destroy(buf);
@ -1060,8 +1057,7 @@ void WindowTransform::appendChunk(Chunk & chunk)
continue;
}
const auto & f = ws.window_function;
const auto * a = f.aggregate_function.get();
const auto * a = ws.aggregate_function.get();
auto * buf = ws.aggregate_function_state.data();
a->create(buf);
@ -1314,6 +1310,71 @@ struct WindowFunctionRowNumber final : public WindowFunction
}
};
struct WindowFunctionLagLead final : public WindowFunction
{
bool is_lag = false;
// Always positive.
uint64_t offset_rows = 1;
Field default_value;
WindowFunctionLagLead(const std::string & name_,
const DataTypes & argument_types_, const Array & parameters_,
bool is_lag_)
: WindowFunction(name_, argument_types_, parameters_)
, is_lag(is_lag_)
{
// offset and default are in parameters
if (argument_types.size() != 1)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"The window function {} must have exactly one argument -- the value column. The offset and the default value must be specified as parameters, i.e. `{}(offset, default)(column)`",
getName(), getName());
}
if (parameters.size() > 2)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"The window function {} accepts at most two parameters, {} given",
getName(), parameters.size());
}
if (parameters.size() >= 1)
{
if (!isInt64FieldType(parameters[0].getType())
|| parameters[0].get<Int64>() < 0)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"The first parameter of the window function {} must be a nonnegative integer specifying the number of offset rows. Got '{}' instead",
getName(), toString(parameters[0]));
}
offset_rows = parameters[0].get<UInt64>();
}
if (parameters.size() >= 2)
{
default_value = convertFieldToTypeOrThrow(parameters[1],
*argument_types[0]);
}
}
DataTypePtr getReturnType() const override { return argument_types[0]; }
void windowInsertResultInto(IColumn &, const WindowTransform *) override
{
// These functions are a mess... they ignore the frame, so we need to
// either materialize the whole partition (not practical if it's big),
// or track a separate frame for these functions, which would make the
// window transform completely impenetrable to human mind. Our best bet
// is probably rewriting, say, `lag(value, offset)` to
// `any(value) over rows between offset preceding and offset preceding`,
// at the query planning stage. We can keep this class as a stub for
// parsing, anyway.
throw Exception(ErrorCodes::NOT_IMPLEMENTED,
"The window function {} is not implemented",
getName());
}
};
void registerWindowFunctions(AggregateFunctionFactory & factory)
{
@ -1337,6 +1398,20 @@ void registerWindowFunctions(AggregateFunctionFactory & factory)
return std::make_shared<WindowFunctionRowNumber>(name, argument_types,
parameters);
});
factory.registerFunction("lag", [](const std::string & name,
const DataTypes & argument_types, const Array & parameters)
{
return std::make_shared<WindowFunctionLagLead>(name, argument_types,
parameters, true /* is_lag */);
});
factory.registerFunction("lead", [](const std::string & name,
const DataTypes & argument_types, const Array & parameters)
{
return std::make_shared<WindowFunctionLagLead>(name, argument_types,
parameters, false /* is_lag */);
});
}
}

View File

@ -19,14 +19,18 @@ class Arena;
// Runtime data for computing one window function.
struct WindowFunctionWorkspace
{
WindowFunctionDescription window_function;
AlignedBuffer aggregate_function_state;
std::vector<size_t> argument_column_indices;
AggregateFunctionPtr aggregate_function;
// This field is set for pure window functions. When set, we ignore the
// window_function.aggregate_function, and work through this interface
// instead.
IWindowFunction * window_function_impl = nullptr;
std::vector<size_t> argument_column_indices;
// Will not be initialized for a pure window function.
AlignedBuffer aggregate_function_state;
// Argument columns. Be careful, this is a per-block cache.
std::vector<const IColumn *> argument_columns;
uint64_t cached_block_number = std::numeric_limits<uint64_t>::max();

View File

@ -962,3 +962,9 @@ settings max_block_size = 2;
26 5 2 5 4 3 4
29 5 2 5 4 3 5
30 6 0 1 1 1 1
-- very bad functions, not implemented yet
select
lag(1, 5)(number) over (),
lead(2)(number) over (),
lag(number) over ()
from numbers(2); -- { serverError 48 }

View File

@ -327,3 +327,10 @@ from (select number, intDiv(number, 5) p, mod(number, 3) o
window w as (partition by p order by o)
order by p, o, number
settings max_block_size = 2;
-- very bad functions, not implemented yet
select
lag(1, 5)(number) over (),
lead(2)(number) over (),
lag(number) over ()
from numbers(2); -- { serverError 48 }