From b36a93a25d880ee95c6167c228d7ef8b9d97fcc2 Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Thu, 14 Dec 2023 09:52:29 +0000 Subject: [PATCH] Revert "Remove `arrayFold`" This reverts commit 15dc0ed610998b847cb0752f5721c55d538fb629. --- .../functions/array-functions.md | 54 ++++ src/Functions/array/arrayFold.cpp | 236 ++++++++++++++++++ tests/performance/array_fold.xml | 5 + .../0_stateless/02718_array_fold.reference | 25 ++ .../queries/0_stateless/02718_array_fold.sql | 24 ++ 5 files changed, 344 insertions(+) create mode 100644 src/Functions/array/arrayFold.cpp create mode 100644 tests/performance/array_fold.xml create mode 100644 tests/queries/0_stateless/02718_array_fold.reference create mode 100644 tests/queries/0_stateless/02718_array_fold.sql diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index a058e1db6b4..00efa63c960 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -1081,6 +1081,10 @@ Result: └─────────────────────────────────────────────────────────────┘ ``` +**See also** + +- [arrayFold](#arrayfold) + ## arrayReduceInRanges Applies an aggregate function to array elements in given ranges and returns an array containing the result corresponding to each range. The function will return the same result as multiple `arrayReduce(agg_func, arraySlice(arr1, index, length), ...)`. @@ -1123,6 +1127,56 @@ Result: └─────────────────────────────┘ ``` +## arrayFold + +Applies a lambda function to one or more equally-sized arrays and collects the result in an accumulator. 
+ +**Syntax** + +``` sql +arrayFold(lambda_function, arr1, arr2, ..., accumulator) +``` + +**Example** + +Query: + +``` sql +SELECT arrayFold( acc,x -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res; +``` + +Result: + +``` text +┌─res─┐ +│ 23 │ +└─────┘ +``` + +**Example with the Fibonacci sequence** + +```sql +SELECT arrayFold( acc,x -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci +FROM numbers(1,10); + +┌─fibonacci─┐ +│ 0 │ +│ 1 │ +│ 1 │ +│ 2 │ +│ 3 │ +│ 5 │ +│ 8 │ +│ 13 │ +│ 21 │ +│ 34 │ +└───────────┘ +``` + +**See also** + +- [arrayReduce](#arrayreduce) + ## arrayReverse(arr) Returns an array of the same size as the original array containing the elements in reverse order. diff --git a/src/Functions/array/arrayFold.cpp b/src/Functions/array/arrayFold.cpp new file mode 100644 index 00000000000..b5b650e7289 --- /dev/null +++ b/src/Functions/array/arrayFold.cpp @@ -0,0 +1,236 @@ +#include "FunctionArrayMapped.h" +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_COLUMN; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; + extern const int TYPE_MISMATCH; +} + +/** + * arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, accum_initial) - apply the expression to each element of the array (or set of arrays). 
+ */
+class ArrayFold : public IFunction
+{
+public:
+    static constexpr auto name = "arrayFold";
+    static FunctionPtr create(ContextPtr) { return std::make_shared<ArrayFold>(); }
+    /// Arguments are: a lambda, one or more arrays and the initial accumulator value, so their number is not fixed.
+    bool isVariadic() const override { return true; }
+    size_t getNumberOfArguments() const override { return 0; }
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+    /// Replaces arguments[0] by the lambda signature: (accumulator type, element type of the 1st array, ..., element type of the last array).
+    void getLambdaArgumentTypes(DataTypes & arguments) const override
+    {
+        if (arguments.size() < 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName());
+
+        DataTypes accumulator_and_array_types(arguments.size() - 1);
+        accumulator_and_array_types[0] = arguments.back(); /// The accumulator is the first lambda argument.
+        for (size_t i = 1; i < accumulator_and_array_types.size(); ++i)
+        {
+            const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i]);
+            if (!array_type)
+                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be of type Array, found {} instead", i + 1, getName(), arguments[i]->getName());
+            accumulator_and_array_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
+        }
+
+        const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
+        if (!lambda_function_type || lambda_function_type->getArgumentTypes().size() != accumulator_and_array_types.size())
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a lambda function with {} arguments, found {} instead.",
+                            getName(), accumulator_and_array_types.size(), arguments[0]->getName());
+
+        arguments[0] = std::make_shared<DataTypeFunction>(accumulator_and_array_types);
+    }
+    /// The result type is the type of the accumulator (the last argument); the lambda must return exactly that type, otherwise TYPE_MISMATCH is thrown.
+    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
+    {
+        if (arguments.size() < 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName());
+
+        const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
+        if (!lambda_function_type)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
+
+        auto accumulator_type = arguments.back().type;
+        auto lambda_type = lambda_function_type->getReturnType();
+        if (!accumulator_type->equals(*lambda_type))
+            throw Exception(ErrorCodes::TYPE_MISMATCH,
+                    "Return type of lambda function must be the same as the accumulator type, inferred return type of lambda: {}, inferred type of accumulator: {}",
+                    lambda_type->getName(), accumulator_type->getName());
+
+        return accumulator_type;
+    }
+    /// Folds all rows at once: step `ind` applies the lambda to the accumulators and the ind-th elements of all rows whose arrays have more than `ind` elements.
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
+    {
+        const auto & lambda_function_with_type_and_name = arguments[0];
+
+        if (!lambda_function_with_type_and_name.column)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
+
+        const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_function_with_type_and_name.column.get());
+        if (!lambda_function)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
+
+        ColumnPtr offsets_column;
+        ColumnPtr column_first_array_ptr;
+        const ColumnArray * column_first_array = nullptr;
+        ColumnsWithTypeAndName arrays;
+        arrays.reserve(arguments.size() - 1);
+
+        /// Validate the array arguments (all arguments except the first and the last one) and collect their nested data columns in `arrays`.
+        for (size_t i = 1; i < arguments.size() - 1; ++i)
+        {
+            const auto & array_with_type_and_name = arguments[i];
+            ColumnPtr column_array_ptr = array_with_type_and_name.column;
+            const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
+            if (!column_array)
+            {
+                const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
+                if (!column_const_array)
+                    throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Expected array column, found {}", column_array_ptr->getName());
+                column_array_ptr = recursiveRemoveLowCardinality(column_const_array->convertToFullColumn());
+                column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
+            }
+
+            const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
+            const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
+            if (!array_type)
+                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Expected array type, found {}", array_type_ptr->getName());
+
+            if (!offsets_column)
+                offsets_column = column_array->getOffsetsPtr();
+            else
+            {
+                /// The first condition is an optimization: do not compare the data if the pointers are equal.
+                if (column_array->getOffsetsPtr() != offsets_column
+                    && column_array->getOffsets() != typeid_cast<const ColumnArray::ColumnOffsets &>(*offsets_column).getData())
+                    throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Arrays passed to {} must have equal size", getName());
+            }
+            if (i == 1)
+            {
+                column_first_array_ptr = column_array_ptr; /// Presumably held only to keep a materialized const array alive while column_first_array is used - confirm.
+                column_first_array = column_array;
+            }
+            arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(),
+                                                      recursiveRemoveLowCardinality(array_type->getNestedType()),
+                                                      array_with_type_and_name.name));
+        }
+
+        ssize_t rows_count = input_rows_count;
+        ssize_t data_row_count = arrays[0].column->size();
+        size_t array_count = arrays.size();
+
+        if (rows_count == 0)
+            return arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();
+
+        ColumnPtr current_column = arguments.back().column->convertToFullColumnIfConst();
+        MutableColumnPtr result_data = current_column->cloneEmpty();
+
+        size_t max_array_size = 0;
+        const auto & offsets = column_first_array->getOffsets(); /// NOTE(review): read at index -1 for the first row below; presumably the container is left-padded with a zero element - confirm.
+
+        IColumn::Selector selector(data_row_count);
+        size_t cur_ind = 0;
+        ssize_t cur_arr = 0;
+
+        /// Skip to the first non-empty array
+        if (data_row_count)
+            while (offsets[cur_arr] == 0)
+                ++cur_arr;
+
+        /// selector[i] is the position of the i-th data element within the array it belongs to
+        for (ssize_t i = 0; i < data_row_count; ++i)
+        {
+            selector[i] = cur_ind;
+            cur_ind++;
+            if (cur_ind > max_array_size)
+                max_array_size = cur_ind;
+            while (cur_arr < rows_count && cur_ind >= offsets[cur_arr] - offsets[cur_arr - 1])
+            {
+                ++cur_arr;
+                cur_ind = 0;
+            }
+        }
+
+        std::vector<MutableColumns> data_arrays;
+        data_arrays.resize(array_count);
+
+        /// Split each data column into max_array_size columns, the N-th of which contains only the elements at position N of their arrays
+        if (max_array_size > 0)
+            for (size_t i = 0; i < array_count; ++i)
+                data_arrays[i] = arrays[i].column->scatter(max_array_size, selector);
+
+        size_t prev_size = rows_count;
+
+        IColumn::Permutation inverse_permutation(rows_count);
+        size_t inverse_permutation_count = 0;
+
+        /// After iteration `ind`, current_column contains the accumulators to which the elements at positions 0..ind of the arrays have been applied.
+        /// At each iteration only the rows of current_column whose arrays still have unapplied elements are kept.
+        /// Rows whose calculation is finished are appended to result_data and, as we insert them, their original row numbers are saved in inverse_permutation
+        for (size_t ind = 0; ind < max_array_size; ++ind)
+        {
+            IColumn::Selector prev_selector(prev_size);
+            size_t prev_ind = 0;
+            for (ssize_t irow = 0; irow < rows_count; ++irow)
+            {
+                if (offsets[irow] - offsets[irow - 1] > ind)
+                    prev_selector[prev_ind++] = 1;
+                else if (offsets[irow] - offsets[irow - 1] == ind)
+                {
+                    inverse_permutation[inverse_permutation_count++] = irow;
+                    prev_selector[prev_ind++] = 0;
+                }
+            }
+            auto prev = current_column->scatter(2, prev_selector); /// prev[0]: finished rows, prev[1]: rows that are folded further
+
+            result_data->insertRangeFrom(*(prev[0]), 0, prev[0]->size());
+
+            auto res_lambda = lambda_function->cloneResized(prev[1]->size());
+            auto * res_lambda_ptr = typeid_cast<ColumnFunction *>(res_lambda.get());
+
+            res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)}));
+            for (size_t i = 0; i < array_count; i++)
+                res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(data_arrays[i][ind]), arrays[i].type, arrays[i].name)}));
+
+            current_column = IColumn::mutate(res_lambda_ptr->reduce().column);
+            prev_size = current_column->size();
+        }
+
+        result_data->insertRangeFrom(*current_column, 0, current_column->size());
+        for (ssize_t irow = 0; irow < rows_count; ++irow)
+            if (offsets[irow] - offsets[irow - 1] == max_array_size)
+                inverse_permutation[inverse_permutation_count++] = irow;
+
+        /// result_data now contains the result of every row and inverse_permutation[k] is the input row which result_data[k] corresponds to.
+        /// Now we need to invert inverse_permutation and apply it to result_data to get the rows in the right order.
+        IColumn::Permutation perm(rows_count);
+        for (ssize_t i = 0; i < rows_count; i++)
+            perm[inverse_permutation[i]] = i;
+        return result_data->permute(perm, 0);
+    }
+
+private:
+    String getName() const override
+    {
+        return name;
+    }
+};
+
+REGISTER_FUNCTION(ArrayFold)
+{
+    factory.registerFunction<ArrayFold>(FunctionDocumentation{.description=R"(
+        Function arrayFold(accum,x1,...,xn -> expression, array1,...,arrayn, accum_initial) applies lambda function to a number of equally-sized arrays
+        and collects the result in an accumulator.
+ )", .examples{{"sum", "SELECT arrayFold(x,acc -> acc+x, [1,2,3,4], toInt64(1));", "11"}}, .categories{"Array"}}); +} +} diff --git a/tests/performance/array_fold.xml b/tests/performance/array_fold.xml new file mode 100644 index 00000000000..32bd45beb1e --- /dev/null +++ b/tests/performance/array_fold.xml @@ -0,0 +1,5 @@ + + SELECT arrayFold((acc, x) -> acc + x, range(number % 100), toUInt64(0)) from numbers(100000) Format Null + SELECT arrayFold((acc, x) -> acc + 1, range(number % 100), toUInt64(0)) from numbers(100000) Format Null + SELECT arrayFold((acc, x) -> acc + x, range(number), toUInt64(0)) from numbers(10000) Format Null + diff --git a/tests/queries/0_stateless/02718_array_fold.reference b/tests/queries/0_stateless/02718_array_fold.reference new file mode 100644 index 00000000000..4139232d145 --- /dev/null +++ b/tests/queries/0_stateless/02718_array_fold.reference @@ -0,0 +1,25 @@ +Negative tests +Const arrays +23 +3 +101 +[1,2,3,4] +[4,3,2,1] +([4,3,2,1],[1,2,3,4]) +([1,3,5],[2,4,6]) +Non-const arrays +0 +1 +3 +6 +10 +[] +[0] +[1,0] +[2,1,0] +[3,2,1,0] +[] +[0] +[1,0] +[1,0,2] +[3,1,0,2] diff --git a/tests/queries/0_stateless/02718_array_fold.sql b/tests/queries/0_stateless/02718_array_fold.sql new file mode 100644 index 00000000000..0486a5ce2e3 --- /dev/null +++ b/tests/queries/0_stateless/02718_array_fold.sql @@ -0,0 +1,24 @@ +SELECT 'Negative tests'; +SELECT arrayFold(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +SELECT arrayFold(1); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +SELECT arrayFold(1, toUInt64(0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +SELECT arrayFold(1, emptyArrayUInt64(), toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH } +SELECT arrayFold( acc,x -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x,y -> x, [0, 1], 'not an array', toUInt8(0)); -- { 
serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x,y -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH } + +SELECT 'Const arrays'; +SELECT arrayFold( acc,x -> acc+x*2, [1, 2, 3, 4], toInt64(3)); +SELECT arrayFold( acc,x -> acc+x*2, emptyArrayInt64(), toInt64(3)); +SELECT arrayFold( acc,x,y -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3)); +SELECT arrayFold( acc,x -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64()); +SELECT arrayFold( acc,x -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64()); +SELECT arrayFold( acc,x -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64())); +SELECT arrayFold( acc,x -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64())); + +SELECT 'Non-const arrays'; +SELECT arrayFold( acc,x -> acc+x, range(number), number) FROM system.numbers LIMIT 5; +SELECT arrayFold( acc,x -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5; +SELECT arrayFold( acc,x -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;