diff --git a/src/Processors/Transforms/WindowTransform.cpp b/src/Processors/Transforms/WindowTransform.cpp index 3af72aff2cd..f5789b6065b 100644 --- a/src/Processors/Transforms/WindowTransform.cpp +++ b/src/Processors/Transforms/WindowTransform.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -2085,6 +2086,74 @@ struct WindowFunctionLagLeadInFrame final : public WindowFunction } }; +struct WindowFunctionNthValue final : public WindowFunction /* Window function nth_value(x, n): for the current row, emits the value of x taken from the n-th row (1-based) of the window frame, or a default value when that row is outside the frame. */ +{ + WindowFunctionNthValue(const std::string & name_, + const DataTypes & argument_types_, const Array & parameters_) + : WindowFunction(name_, argument_types_, parameters_) + { /* Validates the call shape; each violation throws BAD_ARGUMENTS: no parameters, exactly two arguments, and a second argument whose default Field is Int64 or UInt64. */ + if (!parameters.empty()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Function {} cannot be parameterized", name_); + } + + if (argument_types.size() != 2) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Function {} takes exactly two arguments", name_); + } + + if (!isInt64OrUInt64FieldType(argument_types[1]->getDefault().getType())) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "Offset must be an integer, '{}' given", + argument_types[1]->getName()); + } + } + + DataTypePtr getReturnType() const override { return argument_types[0]; } /* the result has the type of the first argument */ + + bool allocatesMemoryInArena() const override { return false; } + + void windowInsertResultInto(const WindowTransform * transform, + size_t function_index) override + { /* Appends one value for transform->current_row to the output column of this function. */ + const auto & current_block = transform->blockAt(transform->current_row); + IColumn & to = *current_block.output_columns[function_index]; + const auto & workspace = transform->workspaces[function_index]; + + int64_t offset = (*current_block.input_columns[ + workspace.argument_column_indices[1]])[ + transform->current_row.row].get(); /* n is read from the second argument at the current row. NOTE(review): get() carries no template argument in this text - presumably lost in extraction; confirm against the real patch. */ + + /// A non-positive value here is either a genuinely non-positive offset or one that overflowed int64_t; neither is acceptable. 
+ if (offset <= 0) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, + "The offset for function {} must be in (0, {}], {} given", + getName(), INT64_MAX, offset); + } + + --offset; /* n is 1-based; make it a 0-based distance from frame_start */ + const auto [target_row, offset_left] = transform->moveRowNumber(transform->frame_start, offset); /* NOTE(review): presumably offset_left is the part of the offset that could not be applied - confirm in moveRowNumber */ + if (offset_left != 0 + || target_row < transform->frame_start + || transform->frame_end <= target_row) + { + // The n-th row is outside the frame: insert a default value. + to.insertDefault(); + } + else + { + // The n-th row is inside the frame: copy the value of the first argument from that row. + to.insertFrom(*transform->blockAt(target_row).input_columns[ + workspace.argument_column_indices[0]], + target_row.row); + } + } +}; + void registerWindowFunctions(AggregateFunctionFactory & factory) { @@ -2136,6 +2205,13 @@ void registerWindowFunctions(AggregateFunctionFactory & factory) parameters); }, properties}, AggregateFunctionFactory::CaseInsensitive); + factory.registerFunction("nth_value", {[](const std::string & name, + const DataTypes & argument_types, const Array & parameters, const Settings *) + { + return std::make_shared( + name, argument_types, parameters); + }, properties}, AggregateFunctionFactory::CaseInsensitive); + factory.registerFunction("lagInFrame", {[](const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) { diff --git a/tests/queries/0_stateless/01591_window_functions.reference b/tests/queries/0_stateless/01591_window_functions.reference index c766bf16f19..47da43399e6 100644 --- a/tests/queries/0_stateless/01591_window_functions.reference +++ b/tests/queries/0_stateless/01591_window_functions.reference @@ -1091,6 +1091,70 @@ order by number 7 6 8 8 7 9 9 8 9 +-- nth_value without specific frame range given +select + number, + nth_value(number, 1) over w as firstValue, + nth_value(number, 2) over w as secondValue, + nth_value(number, 3) over w as thirdValue, + nth_value(number, 4) over w as fourthValue +from numbers(10) +window w as (order by number) +order by number +; +0 0 0 0 0 +1 0 1 0 0 +2 0 1 2 0 +3 
0 1 2 3 +4 0 1 2 3 +5 0 1 2 3 +6 0 1 2 3 +7 0 1 2 3 +8 0 1 2 3 +9 0 1 2 3 +-- nth_value with frame range specified +select + number, + nth_value(number, 1) over w as firstValue, + nth_value(number, 2) over w as secondValue, + nth_value(number, 3) over w as thirdValue, + nth_value(number, 4) over w as fourthValue +from numbers(10) +window w as (order by number range between 1 preceding and 1 following) +order by number +; +0 0 1 0 0 +1 0 1 2 0 +2 1 2 3 0 +3 2 3 4 0 +4 3 4 5 0 +5 4 5 6 0 +6 5 6 7 0 +7 6 7 8 0 +8 7 8 9 0 +9 8 9 0 0 +-- to make nth_value return null for out-of-frame rows, cast the argument to +-- Nullable; otherwise, it returns default values. +SELECT + number, + nth_value(toNullable(number), 1) OVER w as firstValue, + nth_value(toNullable(number), 3) OVER w as thridValue +FROM numbers(5) +WINDOW w AS (ORDER BY number ASC) +; +0 0 \N +1 0 \N +2 0 2 +3 0 2 +4 0 2 +-- nth_value UBsan +SELECT nth_value(1, -1) OVER (); -- { serverError BAD_ARGUMENTS } +SELECT nth_value(1, 0) OVER (); -- { serverError BAD_ARGUMENTS } +SELECT nth_value(1, /* INT64_MAX+1 */ 0x7fffffffffffffff+1) OVER (); -- { serverError BAD_ARGUMENTS } +SELECT nth_value(1, /* INT64_MAX */ 0x7fffffffffffffff) OVER (); +0 +SELECT nth_value(1, 1) OVER (); +1 -- lagInFrame UBsan SELECT lagInFrame(1, -1) OVER (); -- { serverError BAD_ARGUMENTS } SELECT lagInFrame(1, 0) OVER (); @@ -1109,6 +1173,12 @@ SELECT leadInFrame(1, /* INT64_MAX */ 0x7fffffffffffffff) OVER (); 0 SELECT leadInFrame(1, 1) OVER (); 0 +-- nth_value Msan +SELECT nth_value(1, '') OVER (); -- { serverError BAD_ARGUMENTS } +-- lagInFrame Msan +SELECT lagInFrame(1, '') OVER (); -- { serverError BAD_ARGUMENTS } +-- leadInFrame Msan +SELECT leadInFrame(1, '') OVER (); -- { serverError BAD_ARGUMENTS } -- In this case, we had a problem with PartialSortingTransform returning zero-row -- chunks for input chunks w/o columns. 
select count() over () from numbers(4) where number < 2; diff --git a/tests/queries/0_stateless/01591_window_functions.sql b/tests/queries/0_stateless/01591_window_functions.sql index 4a900045c6d..31cfa181f9c 100644 --- a/tests/queries/0_stateless/01591_window_functions.sql +++ b/tests/queries/0_stateless/01591_window_functions.sql @@ -403,6 +403,47 @@ window w as (order by number range between 1 preceding and 1 following) order by number ; +-- nth_value without specific frame range given +select + number, + nth_value(number, 1) over w as firstValue, + nth_value(number, 2) over w as secondValue, + nth_value(number, 3) over w as thirdValue, + nth_value(number, 4) over w as fourthValue +from numbers(10) +window w as (order by number) +order by number +; + +-- nth_value with frame range specified +select + number, + nth_value(number, 1) over w as firstValue, + nth_value(number, 2) over w as secondValue, + nth_value(number, 3) over w as thirdValue, + nth_value(number, 4) over w as fourthValue +from numbers(10) +window w as (order by number range between 1 preceding and 1 following) +order by number +; + +-- to make nth_value return null for out-of-frame rows, cast the argument to +-- Nullable; otherwise, it returns default values. 
+SELECT + number, + nth_value(toNullable(number), 1) OVER w as firstValue, + nth_value(toNullable(number), 3) OVER w as thridValue +FROM numbers(5) +WINDOW w AS (ORDER BY number ASC) +; + +-- nth_value UBsan +SELECT nth_value(1, -1) OVER (); -- { serverError BAD_ARGUMENTS } +SELECT nth_value(1, 0) OVER (); -- { serverError BAD_ARGUMENTS } +SELECT nth_value(1, /* INT64_MAX+1 */ 0x7fffffffffffffff+1) OVER (); -- { serverError BAD_ARGUMENTS } +SELECT nth_value(1, /* INT64_MAX */ 0x7fffffffffffffff) OVER (); +SELECT nth_value(1, 1) OVER (); + -- lagInFrame UBsan SELECT lagInFrame(1, -1) OVER (); -- { serverError BAD_ARGUMENTS } SELECT lagInFrame(1, 0) OVER (); @@ -417,6 +458,15 @@ SELECT leadInFrame(1, /* INT64_MAX+1 */ 0x7fffffffffffffff+1) OVER (); -- { serv SELECT leadInFrame(1, /* INT64_MAX */ 0x7fffffffffffffff) OVER (); SELECT leadInFrame(1, 1) OVER (); +-- nth_value Msan +SELECT nth_value(1, '') OVER (); -- { serverError BAD_ARGUMENTS } + +-- lagInFrame Msan +SELECT lagInFrame(1, '') OVER (); -- { serverError BAD_ARGUMENTS } + +-- leadInFrame Msan +SELECT leadInFrame(1, '') OVER (); -- { serverError BAD_ARGUMENTS } + -- In this case, we had a problem with PartialSortingTransform returning zero-row -- chunks for input chunks w/o columns. select count() over () from numbers(4) where number < 2;