From b9b66e76ddd33f45359b661d19d6b79801d6f7cc Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Mon, 23 Oct 2023 11:06:34 +0000 Subject: [PATCH] Switch accumulator and array arguments --- .../functions/array-functions.md | 11 ++++- src/Functions/array/arrayFold.cpp | 46 +++++++++---------- .../queries/0_stateless/02718_array_fold.sql | 33 ++++++------- 3 files changed, 49 insertions(+), 41 deletions(-) diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index 6e460a64bcf..40bfb65e4e8 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -1081,6 +1081,10 @@ Result: └─────────────────────────────────────────────────────────────┘ ``` +**See also** + +- [arrayFold](#arrayFold) + ## arrayReduceInRanges Applies an aggregate function to array elements in given ranges and returns an array containing the result corresponding to each range. The function will return the same result as multiple `arrayReduce(agg_func, arraySlice(arr1, index, length), ...)`. @@ -1138,7 +1142,7 @@ arrayFold(lambda_function, arr1, arr2, ..., accumulator) Query: ``` sql -SELECT arrayFold( x,acc -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res; +SELECT arrayFold( acc,x -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res; ``` Result: @@ -1152,7 +1156,7 @@ Result: **Example with the Fibonacci sequence** ```sql -SELECT arrayFold( x, acc -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci +SELECT arrayFold( acc,x -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci FROM numbers(1,10); ┌─fibonacci─┐ @@ -1169,6 +1173,9 @@ FROM numbers(1,10); └───────────┘ ``` +**See also** + +- [arrayReduce](#arrayReduce) ## arrayReverse(arr) diff --git a/src/Functions/array/arrayFold.cpp b/src/Functions/array/arrayFold.cpp index 94ed5d59ca9..b5b650e7289 100644 --- a/src/Functions/array/arrayFold.cpp +++ b/src/Functions/array/arrayFold.cpp @@ -30,37 +30,37 @@ public: void getLambdaArgumentTypes(DataTypes & arguments) const override { if (arguments.size() < 3) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator argument", getName()); + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName()); - DataTypes nested_types(arguments.size() - 1); - for (size_t i = 0; i < nested_types.size() - 1; ++i) + DataTypes accumulator_and_array_types(arguments.size() - 1); + accumulator_and_array_types[0] = arguments.back(); + for (size_t i = 1; i < accumulator_and_array_types.size(); ++i) { - const auto * array_type = checkAndGetDataType(&*arguments[i + 1]); + const auto * array_type = checkAndGetDataType(&*arguments[i]); if (!array_type) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array, found {} instead", i + 2, getName(), arguments[i + 1]->getName()); - nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType()); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be of type Array, found {} instead", i + 1, getName(), arguments[i]->getName()); + accumulator_and_array_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType()); } - nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1]; - const auto * function_type = checkAndGetDataType(arguments[0].get()); - if (!function_type || function_type->getArgumentTypes().size() != nested_types.size()) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments, found {} instead.", - getName(), nested_types.size(), arguments[0]->getName()); + const auto * lambda_function_type = checkAndGetDataType(arguments[0].get()); + if (!lambda_function_type || lambda_function_type->getArgumentTypes().size() != accumulator_and_array_types.size()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a lambda function with {} arguments, found {} instead.", + getName(), accumulator_and_array_types.size(), arguments[0]->getName()); - arguments[0] = std::make_shared(nested_types); + arguments[0] = std::make_shared(accumulator_and_array_types); } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { - if (arguments.size() < 2) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least 2 arguments, passed: {}.", getName(), arguments.size()); + if (arguments.size() < 3) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName()); - const auto * data_type_function = checkAndGetDataType(arguments[0].type.get()); - if (!data_type_function) + const auto * lambda_function_type = checkAndGetDataType(arguments[0].type.get()); + if (!lambda_function_type) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName()); auto accumulator_type = arguments.back().type; - auto lambda_type = data_type_function->getReturnType(); + auto lambda_type = lambda_function_type->getReturnType(); if (!accumulator_type->equals(*lambda_type)) throw Exception(ErrorCodes::TYPE_MISMATCH, "Return type of lambda function must be the same as the accumulator type, inferred return type of lambda: {}, inferred type of accumulator: {}", @@ -71,12 +71,12 @@ public: ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { - const auto & lambda_with_type_and_name = arguments[0]; + const auto & lambda_function_with_type_and_name = arguments[0]; - if (!lambda_with_type_and_name.column) + if (!lambda_function_with_type_and_name.column) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName()); - const auto * lambda_function = typeid_cast(lambda_with_type_and_name.column.get()); + const auto * lambda_function = typeid_cast(lambda_function_with_type_and_name.column.get()); if (!lambda_function) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName()); @@ -85,6 +85,7 @@ public: const ColumnArray * column_first_array = nullptr; ColumnsWithTypeAndName arrays; arrays.reserve(arguments.size() - 1); + /// Validate input types and get input array columns in convenient form for (size_t i = 1; i < arguments.size() - 1; ++i) { @@ -131,8 +132,7 @@ public: if (rows_count == 0) return arguments.back().column->convertToFullColumnIfConst()->cloneEmpty(); - ColumnPtr current_column; - current_column = arguments.back().column->convertToFullColumnIfConst(); + ColumnPtr current_column = arguments.back().column->convertToFullColumnIfConst(); MutableColumnPtr result_data = arguments.back().column->convertToFullColumnIfConst()->cloneEmpty(); size_t max_array_size = 0; @@ -198,9 +198,9 @@ public: auto res_lambda = lambda_function->cloneResized(prev[1]->size()); auto * res_lambda_ptr = typeid_cast(res_lambda.get()); + res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)})); for (size_t i = 0; i < array_count; i++) res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(data_arrays[i][ind]), arrays[i].type, arrays[i].name)})); - res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)})); current_column = IColumn::mutate(res_lambda_ptr->reduce().column); prev_size = current_column->size(); diff --git a/tests/queries/0_stateless/02718_array_fold.sql b/tests/queries/0_stateless/02718_array_fold.sql index 7f20602a371..0486a5ce2e3 100644 --- a/tests/queries/0_stateless/02718_array_fold.sql +++ b/tests/queries/0_stateless/02718_array_fold.sql @@ -1,23 +1,24 @@ SELECT 'Negative tests'; SELECT arrayFold(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } SELECT arrayFold(1); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } -SELECT arrayFold(1, toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayFold( x,acc -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH } -SELECT arrayFold( x,acc -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayFold( x,y,acc -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayFold( x,acc -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayFold( x,y,acc -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH } +SELECT arrayFold(1, toUInt64(0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH } +SELECT arrayFold(1, emptyArrayUInt64(), toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH } +SELECT arrayFold( acc,x -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x,y -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT arrayFold( acc,x,y -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH } SELECT 'Const arrays'; -SELECT arrayFold( x,acc -> acc+x*2, [1, 2, 3, 4], toInt64(3)); -SELECT arrayFold( x,acc -> acc+x*2, emptyArrayInt64(), toInt64(3)); -SELECT arrayFold( x,y,acc -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3)); -SELECT arrayFold( x,acc -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64()); -SELECT arrayFold( x,acc -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64()); -SELECT arrayFold( x,acc -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64())); -SELECT arrayFold( x,acc -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64())); +SELECT arrayFold( acc,x -> acc+x*2, [1, 2, 3, 4], toInt64(3)); +SELECT arrayFold( acc,x -> acc+x*2, emptyArrayInt64(), toInt64(3)); +SELECT arrayFold( acc,x,y -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3)); +SELECT arrayFold( acc,x -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64()); +SELECT arrayFold( acc,x -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64()); +SELECT arrayFold( acc,x -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64())); +SELECT arrayFold( acc,x -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64())); SELECT 'Non-const arrays'; -SELECT arrayFold( x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 5; -SELECT arrayFold( x,acc -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5; -SELECT arrayFold( x,acc -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5; +SELECT arrayFold( acc,x -> acc+x, range(number), number) FROM system.numbers LIMIT 5; +SELECT arrayFold( acc,x -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5; +SELECT arrayFold( acc,x -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;