mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-30 03:22:14 +00:00
Switch accumulator and array arguments
This commit is contained in:
parent
002f7ded74
commit
b9b66e76dd
@ -1081,6 +1081,10 @@ Result:
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**See also**
|
||||
|
||||
- [arrayFold](#arrayFold)
|
||||
|
||||
## arrayReduceInRanges
|
||||
|
||||
Applies an aggregate function to array elements in given ranges and returns an array containing the result corresponding to each range. The function will return the same result as multiple `arrayReduce(agg_func, arraySlice(arr1, index, length), ...)`.
|
||||
@ -1138,7 +1142,7 @@ arrayFold(lambda_function, arr1, arr2, ..., accumulator)
|
||||
Query:
|
||||
|
||||
``` sql
|
||||
SELECT arrayFold( x,acc -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res;
|
||||
SELECT arrayFold( acc,x -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res;
|
||||
```
|
||||
|
||||
Result:
|
||||
@ -1152,7 +1156,7 @@ Result:
|
||||
**Example with the Fibonacci sequence**
|
||||
|
||||
```sql
|
||||
SELECT arrayFold( x, acc -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci
|
||||
SELECT arrayFold( acc,x -> (acc.2, acc.2 + acc.1), range(number), (1::Int64, 0::Int64)).1 AS fibonacci
|
||||
FROM numbers(1,10);
|
||||
|
||||
┌─fibonacci─┐
|
||||
@ -1169,6 +1173,9 @@ FROM numbers(1,10);
|
||||
└───────────┘
|
||||
```
|
||||
|
||||
**See also**
|
||||
|
||||
- [arrayReduce](#arrayReduce)
|
||||
|
||||
## arrayReverse(arr)
|
||||
|
||||
|
@ -30,37 +30,37 @@ public:
|
||||
void getLambdaArgumentTypes(DataTypes & arguments) const override
|
||||
{
|
||||
if (arguments.size() < 3)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator argument", getName());
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName());
|
||||
|
||||
DataTypes nested_types(arguments.size() - 1);
|
||||
for (size_t i = 0; i < nested_types.size() - 1; ++i)
|
||||
DataTypes accumulator_and_array_types(arguments.size() - 1);
|
||||
accumulator_and_array_types[0] = arguments.back();
|
||||
for (size_t i = 1; i < accumulator_and_array_types.size(); ++i)
|
||||
{
|
||||
const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
|
||||
const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i]);
|
||||
if (!array_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array, found {} instead", i + 2, getName(), arguments[i + 1]->getName());
|
||||
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be of type Array, found {} instead", i + 1, getName(), arguments[i]->getName());
|
||||
accumulator_and_array_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
|
||||
}
|
||||
nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1];
|
||||
|
||||
const auto * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
|
||||
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments, found {} instead.",
|
||||
getName(), nested_types.size(), arguments[0]->getName());
|
||||
const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
|
||||
if (!lambda_function_type || lambda_function_type->getArgumentTypes().size() != accumulator_and_array_types.size())
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a lambda function with {} arguments, found {} instead.",
|
||||
getName(), accumulator_and_array_types.size(), arguments[0]->getName());
|
||||
|
||||
arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
|
||||
arguments[0] = std::make_shared<DataTypeFunction>(accumulator_and_array_types);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
if (arguments.size() < 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least 2 arguments, passed: {}.", getName(), arguments.size());
|
||||
if (arguments.size() < 3)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator", getName());
|
||||
|
||||
const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
|
||||
if (!data_type_function)
|
||||
const auto * lambda_function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
|
||||
if (!lambda_function_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
|
||||
|
||||
auto accumulator_type = arguments.back().type;
|
||||
auto lambda_type = data_type_function->getReturnType();
|
||||
auto lambda_type = lambda_function_type->getReturnType();
|
||||
if (!accumulator_type->equals(*lambda_type))
|
||||
throw Exception(ErrorCodes::TYPE_MISMATCH,
|
||||
"Return type of lambda function must be the same as the accumulator type, inferred return type of lambda: {}, inferred type of accumulator: {}",
|
||||
@ -71,12 +71,12 @@ public:
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
|
||||
{
|
||||
const auto & lambda_with_type_and_name = arguments[0];
|
||||
const auto & lambda_function_with_type_and_name = arguments[0];
|
||||
|
||||
if (!lambda_with_type_and_name.column)
|
||||
if (!lambda_function_with_type_and_name.column)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
|
||||
|
||||
const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_with_type_and_name.column.get());
|
||||
const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_function_with_type_and_name.column.get());
|
||||
if (!lambda_function)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
|
||||
|
||||
@ -85,6 +85,7 @@ public:
|
||||
const ColumnArray * column_first_array = nullptr;
|
||||
ColumnsWithTypeAndName arrays;
|
||||
arrays.reserve(arguments.size() - 1);
|
||||
|
||||
/// Validate input types and get input array columns in convenient form
|
||||
for (size_t i = 1; i < arguments.size() - 1; ++i)
|
||||
{
|
||||
@ -131,8 +132,7 @@ public:
|
||||
if (rows_count == 0)
|
||||
return arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();
|
||||
|
||||
ColumnPtr current_column;
|
||||
current_column = arguments.back().column->convertToFullColumnIfConst();
|
||||
ColumnPtr current_column = arguments.back().column->convertToFullColumnIfConst();
|
||||
MutableColumnPtr result_data = arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();
|
||||
|
||||
size_t max_array_size = 0;
|
||||
@ -198,9 +198,9 @@ public:
|
||||
auto res_lambda = lambda_function->cloneResized(prev[1]->size());
|
||||
auto * res_lambda_ptr = typeid_cast<ColumnFunction *>(res_lambda.get());
|
||||
|
||||
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)}));
|
||||
for (size_t i = 0; i < array_count; i++)
|
||||
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(data_arrays[i][ind]), arrays[i].type, arrays[i].name)}));
|
||||
res_lambda_ptr->appendArguments(std::vector({ColumnWithTypeAndName(std::move(prev[1]), arguments.back().type, arguments.back().name)}));
|
||||
|
||||
current_column = IColumn::mutate(res_lambda_ptr->reduce().column);
|
||||
prev_size = current_column->size();
|
||||
|
@ -1,23 +1,24 @@
|
||||
SELECT 'Negative tests';
|
||||
SELECT arrayFold(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
SELECT arrayFold(1); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
SELECT arrayFold(1, toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( x,acc -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH }
|
||||
SELECT arrayFold( x,acc -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( x,y,acc -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( x,acc -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( x,y,acc -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH }
|
||||
SELECT arrayFold(1, toUInt64(0)); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
SELECT arrayFold(1, emptyArrayUInt64(), toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( acc,x -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH }
|
||||
SELECT arrayFold( acc,x -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( acc,x,y -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( acc,x -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( acc,x,y -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH }
|
||||
|
||||
SELECT 'Const arrays';
|
||||
SELECT arrayFold( x,acc -> acc+x*2, [1, 2, 3, 4], toInt64(3));
|
||||
SELECT arrayFold( x,acc -> acc+x*2, emptyArrayInt64(), toInt64(3));
|
||||
SELECT arrayFold( x,y,acc -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3));
|
||||
SELECT arrayFold( x,acc -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64());
|
||||
SELECT arrayFold( x,acc -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64());
|
||||
SELECT arrayFold( x,acc -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64()));
|
||||
SELECT arrayFold( x,acc -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64()));
|
||||
SELECT arrayFold( acc,x -> acc+x*2, [1, 2, 3, 4], toInt64(3));
|
||||
SELECT arrayFold( acc,x -> acc+x*2, emptyArrayInt64(), toInt64(3));
|
||||
SELECT arrayFold( acc,x,y -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3));
|
||||
SELECT arrayFold( acc,x -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64());
|
||||
SELECT arrayFold( acc,x -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64());
|
||||
SELECT arrayFold( acc,x -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64()));
|
||||
SELECT arrayFold( acc,x -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64()));
|
||||
|
||||
SELECT 'Non-const arrays';
|
||||
SELECT arrayFold( x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 5;
|
||||
SELECT arrayFold( x,acc -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
|
||||
SELECT arrayFold( x,acc -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
|
||||
SELECT arrayFold( acc,x -> acc+x, range(number), number) FROM system.numbers LIMIT 5;
|
||||
SELECT arrayFold( acc,x -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
|
||||
SELECT arrayFold( acc,x -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
|
||||
|
Loading…
Reference in New Issue
Block a user