Switch accumulator and array arguments

This commit is contained in:
Robert Schulze 2023-10-23 11:06:34 +00:00
parent 002f7ded74
commit b9b66e76dd
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
3 changed files with 49 additions and 41 deletions

View File

@ -1081,6 +1081,10 @@ Result:
└─────────────────────────────────────────────────────────────┘
```
**See also**
- [arrayFold](#arrayFold)
## arrayReduceInRanges
Applies an aggregate function to array elements in given ranges and returns an array containing the result corresponding to each range. The function will return the same result as multiple `arrayReduce(agg_func, arraySlice(arr1, index, length), ...)`.
@ -1138,7 +1142,7 @@ arrayFold(lambda_function, arr1, arr2, ..., accumulator)
Query:
``` sql
SELECT arrayFold( x,acc -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res;
SELECT arrayFold( acc,x -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res;
```
Result:
@ -1152,7 +1156,7 @@ Result:
**Example with the Fibonacci sequence**
```sql
SELECT arrayFold( x, acc -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci
SELECT arrayFold( acc,x -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci
FROM numbers(1,10);
┌─fibonacci─┐
@ -1169,6 +1173,9 @@ FROM numbers(1,10);
└───────────┘
```
**See also**
- [arrayReduce](#arrayReduce)
## arrayReverse(arr)

View File

@ -30,37 +30,37 @@ public:
void getLambdaArgumentTypes(DataTypes & arguments) const override
{
if (arguments.size() < 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator argument", getName());
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName());
DataTypes nested_types(arguments.size() - 1);
for (size_t i = 0; i < nested_types.size() - 1; ++i)
DataTypes accumulator_and_array_types(arguments.size() - 1);
accumulator_and_array_types[0] = arguments.back();
for (size_t i = 1; i < accumulator_and_array_types.size(); ++i)
{
const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i]);
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array, found {} instead", i + 2, getName(), arguments[i + 1]->getName());
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be of type Array, found {} instead", i + 1, getName(), arguments[i]->getName());
accumulator_and_array_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
}
nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1];
const auto * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments, found {} instead.",
getName(), nested_types.size(), arguments[0]->getName());
const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
if (!lambda_function_type || lambda_function_type->getArgumentTypes().size() != accumulator_and_array_types.size())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a lambda function with {} arguments, found {} instead.",
getName(), accumulator_and_array_types.size(), arguments[0]->getName());
arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
arguments[0] = std::make_shared<DataTypeFunction>(accumulator_and_array_types);
}
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() < 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least 2 arguments, passed: {}.", getName(), arguments.size());
if (arguments.size() < 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName());
const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
if (!data_type_function)
const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
if (!lambda_function_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
auto accumulator_type = arguments.back().type;
auto lambda_type = data_type_function->getReturnType();
auto lambda_type = lambda_function_type->getReturnType();
if (!accumulator_type->equals(*lambda_type))
throw Exception(ErrorCodes::TYPE_MISMATCH,
"Return type of lambda function must be the same as the accumulator type, inferred return type of lambda: {}, inferred type of accumulator: {}",
@ -71,12 +71,12 @@ public:
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const auto & lambda_with_type_and_name = arguments[0];
const auto & lambda_function_with_type_and_name = arguments[0];
if (!lambda_with_type_and_name.column)
if (!lambda_function_with_type_and_name.column)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_with_type_and_name.column.get());
const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_function_with_type_and_name.column.get());
if (!lambda_function)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
@ -85,6 +85,7 @@ public:
const ColumnArray * column_first_array = nullptr;
ColumnsWithTypeAndName arrays;
arrays.reserve(arguments.size() - 1);
/// Validate input types and get input array columns in convenient form
for (size_t i = 1; i < arguments.size() - 1; ++i)
{
@ -131,8 +132,7 @@ public:
if (rows_count == 0)
return arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();
ColumnPtr current_column;
current_column = arguments.back().column->convertToFullColumnIfConst();
ColumnPtr current_column = arguments.back().column->convertToFullColumnIfConst();
MutableColumnPtr result_data = arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();
size_t max_array_size = 0;
@ -198,9 +198,9 @@ public:
auto res_lambda = lambda_function->cloneResized(prev[1]->size());
auto * res_lambda_ptr = typeid_cast<ColumnFunction *>(res_lambda.get());
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)}));
for (size_t i = 0; i < array_count; i++)
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(data_arrays[i][ind]), arrays[i].type, arrays[i].name)}));
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)}));
current_column = IColumn::mutate(res_lambda_ptr->reduce().column);
prev_size = current_column->size();

View File

@ -1,23 +1,24 @@
SELECT 'Negative tests';
SELECT arrayFold(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayFold(1); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayFold(1, toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,acc -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH }
SELECT arrayFold( x,acc -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,y,acc -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,acc -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,y,acc -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH }
SELECT arrayFold(1, toUInt64(0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayFold(1, emptyArrayUInt64(), toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( acc,x -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH }
SELECT arrayFold( acc,x -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( acc,x,y -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( acc,x -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( acc,x,y -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH }
SELECT 'Const arrays';
SELECT arrayFold( x,acc -> acc+x*2, [1, 2, 3, 4], toInt64(3));
SELECT arrayFold( x,acc -> acc+x*2, emptyArrayInt64(), toInt64(3));
SELECT arrayFold( x,y,acc -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3));
SELECT arrayFold( x,acc -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( x,acc -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( x,acc -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64()));
SELECT arrayFold( x,acc -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64()));
SELECT arrayFold( acc,x -> acc+x*2, [1, 2, 3, 4], toInt64(3));
SELECT arrayFold( acc,x -> acc+x*2, emptyArrayInt64(), toInt64(3));
SELECT arrayFold( acc,x,y -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3));
SELECT arrayFold( acc,x -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( acc,x -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( acc,x -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64()));
SELECT arrayFold( acc,x -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64()));
SELECT 'Non-const arrays';
SELECT arrayFold( x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 5;
SELECT arrayFold( x,acc -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
SELECT arrayFold( x,acc -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
SELECT arrayFold( acc,x -> acc+x, range(number), number) FROM system.numbers LIMIT 5;
SELECT arrayFold( acc,x -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
SELECT arrayFold( acc,x -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;