Revert "Function arrayFold for folding over array with accumulator"

This commit is contained in:
alexey-milovidov 2021-04-18 03:34:05 +03:00 committed by GitHub
parent d492beea9f
commit 83038f84af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 0 additions and 423 deletions

View File

@ -1213,62 +1213,6 @@ SELECT arrayFill(x -> not isNull(x), [1, null, 3, 11, 12, null, null, 5, 6, 14,
Note that the `arrayFill` is a [higher-order function](../../sql-reference/functions/index.md#higher-order-functions). You must pass a lambda function to it as the first argument, and it cant be omitted. Note that the `arrayFill` is a [higher-order function](../../sql-reference/functions/index.md#higher-order-functions). You must pass a lambda function to it as the first argument, and it cant be omitted.
## arrayFold(func, arr1, …, init) {#array-fold}
Returns an result of [folding](https://en.wikipedia.org/wiki/Fold_(higher-order_function)) arrays and value `init` using function `func`.
I.e. result of calculation `func(arr1[n], …, func(arr1[n - 1], …, func(…, func(arr1[2], …, func(arr1[1], …, init)))))`.
Note that the `arrayMap` is a [higher-order function](../../sql-reference/functions/index.md#higher-order-functions). You must pass a lambda function to it as the first argument, and it cant be omitted.
**Arguments**
- `func` — The lambda function with `n+1` arguments (where `n` is number of input arrays), first `n` arguments are for
current elements of input arrays, and last argument is for current value of accumulator.
- `arr` — Any number of [arrays](../../sql-reference/data-types/array.md).
- `init` - Initial value of accumulator.
**Returned value**
Final value of accumulator.
**Examples**
The following example shows how to acquire product and sum of elements of array:
``` sql
SELECT arrayMap(x, accum -> (accum.1 * x, accum.2 + x), [1, 2, 3], (0, 1)) as res;
```
``` text
┌─res───────┐
│ (120, 15) │
└───────────┘
```
The following example shows how to reverse elements of array:
``` sql
SELECT arrayFold(x, acc -> arrayPushFront(acc, x), [1,2,3,4,5], emptyArrayUInt64()) as res;
```
``` text
┌─res─────────┐
│ [5,4,3,2,1] │
└─────────────┘
```
Folding may be used to access of already passed elements due to function calculation, for example:
``` sql
SELECT arrayFold(x, acc -> (x, concat(acc.2, toString(acc.1), ',')), [1,2], (0,''))
```
``` text
┌─res────────┐
│ (2,'0,1,') │
└────────────┘
```
## arrayReverseFill(func, arr1, …) {#array-reverse-fill} ## arrayReverseFill(func, arr1, …) {#array-reverse-fill}
Scan through `arr1` from the last element to the first element and replace `arr1[i]` by `arr1[i + 1]` if `func` returns 0. The last element of `arr1` will not be replaced. Scan through `arr1` from the last element to the first element and replace `arr1[i]` by `arr1[i + 1]` if `func` returns 0. The last element of `arr1` will not be replaced.

View File

@ -1147,62 +1147,6 @@ SELECT arrayReverseFill(x -> not isNull(x), [1, null, 3, 11, 12, null, null, 5,
Функция `arrayReverseFill` является [функцией высшего порядка](../../sql-reference/functions/index.md#higher-order-functions) — в качестве первого аргумента ей нужно передать лямбда-функцию, и этот аргумент не может быть опущен. Функция `arrayReverseFill` является [функцией высшего порядка](../../sql-reference/functions/index.md#higher-order-functions) — в качестве первого аргумента ей нужно передать лямбда-функцию, и этот аргумент не может быть опущен.
## arrayFold(func, arr1, …, init) {#array-fold}
Возвращает результат [сворачивания](https://ru.wikipedia.org/wiki/%D0%A1%D0%B2%D1%91%D1%80%D1%82%D0%BA%D0%B0_%D1%81%D0%BF%D0%B8%D1%81%D0%BA%D0%B0) массивов и начального значения `init` с помощью функции `func`.
Т.е. результат вычисления `func(arr1[n], …, func(arr1[n - 1], …, func(…, func(arr1[2], …, func(arr1[1], …, init)))))`.
Функция `arrayFold` является [функцией высшего порядка](../../sql-reference/functions/index.md#higher-order-functions) — в качестве первого аргумента ей нужно передать лямбда-функцию, и этот аргумент не может быть опущен.
**Аргументы**
- `func` — лямбда-функция с `n+1` параметром (где `n` это количество входных массивов), причём первые `n` параметров
используются для текущих элементов входных массивов, а последний элемент для текущего значения аккумулятора.
- `arr` — произвольное количество [массивов](../../sql-reference/data-types/array.md).
- `init` - начальное значение аккумулятора.
**Возвращаемое значение**
Итоговое значение аккумулятора.
**Примеры**
Следующий пример показывает, как вычислить произведение и сумму элементов массива:
``` sql
SELECT arrayMap(x, accum -> (accum.1 * x, accum.2 + x), [1, 2, 3], (0, 1)) as res;
```
``` text
┌─res───────┐
│ (120, 15) │
└───────────┘
```
В этом примере показано, как обратить массив:
``` sql
SELECT arrayFold(x, acc -> arrayPushFront(acc, x), [1,2,3,4,5], emptyArrayUInt64()) as res;
```
``` text
┌─res─────────┐
│ [5,4,3,2,1] │
└─────────────┘
```
Свёртка может быть использована для доступа к уже пройденным в процессе вычисления элементам. Например:
``` sql
SELECT arrayFold(x, acc -> (x, concat(acc.2, toString(acc.1), ',')), [1,2], (0,''))
```
``` text
┌─res────────┐
│ (2,'0,1,') │
└────────────┘
```
## arraySplit(func, arr1, …) {#array-split} ## arraySplit(func, arr1, …) {#array-split}
Разделяет массив `arr1` на несколько. Если `func` возвращает не 0, то массив разделяется, а элемент помещается в левую часть. Массив не разбивается по первому элементу. Разделяет массив `arr1` на несколько. Если `func` возвращает не 0, то массив разделяется, а элемент помещается в левую часть. Массив не разбивается по первому элементу.
@ -1239,7 +1183,6 @@ SELECT arrayReverseSplit((x, y) -> y, [1, 2, 3, 4, 5], [1, 0, 0, 1, 0]) AS res
Функция `arrayReverseSplit` является [функцией высшего порядка](../../sql-reference/functions/index.md#higher-order-functions) — в качестве первого аргумента ей нужно передать лямбда-функцию, и этот аргумент не может быть опущен. Функция `arrayReverseSplit` является [функцией высшего порядка](../../sql-reference/functions/index.md#higher-order-functions) — в качестве первого аргумента ей нужно передать лямбда-функцию, и этот аргумент не может быть опущен.
## arrayExists(\[func,\] arr1, …) {#arrayexistsfunc-arr1} ## arrayExists(\[func,\] arr1, …) {#arrayexistsfunc-arr1}
Возвращает 1, если существует хотя бы один элемент массива `arr`, для которого функция func возвращает не 0. Иначе возвращает 0. Возвращает 1, если существует хотя бы один элемент массива `arr`, для которого функция func возвращает не 0. Иначе возвращает 0.

View File

@ -1,187 +0,0 @@
#include "FunctionArrayMapped.h"
#include <Functions/FunctionFactory.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
extern const int TYPE_MISMATCH;
}
/** arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, init_accum) - apply the expression to each element of the array (or set of parallel arrays).
*/
class FunctionArrayFold : public IFunction
{
public:
static constexpr auto name = "arrayFold";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayFold>(); }
String getName() const override { return name; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
void getLambdaArgumentTypes(DataTypes & arguments) const override
{
if (arguments.size() < 3)
throw Exception("Function " + getName() + " needs lambda function, at least one array argument and one accumulator argument.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
DataTypes nested_types(arguments.size() - 1);
for (size_t i = 0; i < nested_types.size() - 1; ++i)
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
if (!array_type)
throw Exception("Argument " + toString(i + 2) + " of function " + getName() + " must be array. Found "
+ arguments[i + 1]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
}
nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1];
const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
throw Exception("First argument for this overload of " + getName() + " must be a function with "
+ toString(nested_types.size()) + " arguments. Found "
+ arguments[0]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
}
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() < 2)
throw Exception("Function " + getName() + " needs at least 2 arguments; passed "
+ toString(arguments.size()) + ".",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
if (!data_type_function)
throw Exception("First argument for function " + getName() + " must be a function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
auto const accumulator_type = arguments.back().type;
auto const lambda_type = data_type_function->getReturnType();
if (! accumulator_type->equals(*lambda_type))
throw Exception("Return type of lambda function must be the same as the accumulator type. "
"Inferred type of lambda " + lambda_type->getName() + ", "
+ "inferred type of accumulator " + accumulator_type->getName() + ".",
ErrorCodes::TYPE_MISMATCH);
return DataTypePtr(accumulator_type);
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
{
const auto & column_with_type_and_name = arguments[0];
if (!column_with_type_and_name.column)
throw Exception("First argument for function " + getName() + " must be a function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
const auto * column_function = typeid_cast<const ColumnFunction *>(column_with_type_and_name.column.get());
if (!column_function)
throw Exception("First argument for function " + getName() + " must be a function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
ColumnPtr offsets_column;
ColumnPtr column_first_array_ptr;
const ColumnArray * column_first_array = nullptr;
ColumnsWithTypeAndName arrays;
arrays.reserve(arguments.size() - 1);
for (size_t i = 1; i < arguments.size() - 1; ++i)
{
const auto & array_with_type_and_name = arguments[i];
ColumnPtr column_array_ptr = array_with_type_and_name.column;
const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
if (!column_array)
{
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
if (!column_const_array)
throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
column_array_ptr = recursiveRemoveLowCardinality(column_const_array->convertToFullColumn());
column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
}
if (!array_type)
throw Exception("Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (!offsets_column)
{
offsets_column = column_array->getOffsetsPtr();
}
else
{
/// The first condition is optimization: do not compare data if the pointers are equal.
if (column_array->getOffsetsPtr() != offsets_column
&& column_array->getOffsets() != typeid_cast<const ColumnArray::ColumnOffsets &>(*offsets_column).getData())
throw Exception("Arrays passed to " + getName() + " must have equal size", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
}
if (i == 1)
{
column_first_array_ptr = column_array_ptr;
column_first_array = column_array;
}
arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(),
recursiveRemoveLowCardinality(array_type->getNestedType()),
array_with_type_and_name.name));
}
arrays.emplace_back(arguments.back());
MutableColumnPtr result = arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();
size_t arr_cursor = 0;
for (size_t irow = 0; irow < column_first_array->size(); ++irow) // for each row of result
{
// Make accumulator column for this row. We initialize it
// with the starting value given as the last argument.
ColumnWithTypeAndName accumulator_column = arguments.back();
ColumnPtr acc(accumulator_column.column->cut(irow, 1));
auto accumulator = ColumnWithTypeAndName(acc,
accumulator_column.type,
accumulator_column.name);
ColumnPtr res(acc);
size_t const arr_next = column_first_array->getOffsets()[irow]; // when we do folding
for (size_t iter = 0; arr_cursor < arr_next; ++iter, ++arr_cursor)
{
// Make slice of input arrays and accumulator for lambda
ColumnsWithTypeAndName iter_arrays;
iter_arrays.reserve(arrays.size() + 1);
for (size_t icolumn = 0; icolumn < arrays.size() - 1; ++icolumn)
{
auto const & arr = arrays[icolumn];
iter_arrays.emplace_back(ColumnWithTypeAndName(arr.column->cut(arr_cursor, 1),
arr.type,
arr.name));
}
iter_arrays.emplace_back(accumulator);
// Calculate function on arguments
auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(ColumnArray::Offsets(column_first_array->getOffsets().size(), 1)));
auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
replicated_column_function->appendArguments(iter_arrays);
auto lambda_result = replicated_column_function->reduce().column;
if (lambda_result->lowCardinality())
lambda_result = lambda_result->convertToFullColumnIfLowCardinality();
res = lambda_result->cut(0, 1);
accumulator.column = res;
}
result->insert((*res)[0]);
}
return result;
}
};
void registerFunctionArrayFold(FunctionFactory & factory)
{
factory.registerFunction<FunctionArrayFold>();
}
}

View File

@ -4,7 +4,6 @@ namespace DB
class FunctionFactory; class FunctionFactory;
void registerFunctionArrayMap(FunctionFactory & factory); void registerFunctionArrayMap(FunctionFactory & factory);
void registerFunctionArrayFold(FunctionFactory & factory);
void registerFunctionArrayFilter(FunctionFactory & factory); void registerFunctionArrayFilter(FunctionFactory & factory);
void registerFunctionArrayCount(FunctionFactory & factory); void registerFunctionArrayCount(FunctionFactory & factory);
void registerFunctionArrayExists(FunctionFactory & factory); void registerFunctionArrayExists(FunctionFactory & factory);
@ -23,7 +22,6 @@ void registerFunctionArrayDifference(FunctionFactory & factory);
void registerFunctionsHigherOrder(FunctionFactory & factory) void registerFunctionsHigherOrder(FunctionFactory & factory)
{ {
registerFunctionArrayMap(factory); registerFunctionArrayMap(factory);
registerFunctionArrayFold(factory);
registerFunctionArrayFilter(factory); registerFunctionArrayFilter(factory);
registerFunctionArrayCount(factory); registerFunctionArrayCount(factory);
registerFunctionArrayExists(factory); registerFunctionArrayExists(factory);

View File

@ -144,7 +144,6 @@ SRCS(
array/arrayFirst.cpp array/arrayFirst.cpp
array/arrayFirstIndex.cpp array/arrayFirstIndex.cpp
array/arrayFlatten.cpp array/arrayFlatten.cpp
array/arrayFold.cpp
array/arrayIntersect.cpp array/arrayIntersect.cpp
array/arrayJoin.cpp array/arrayJoin.cpp
array/arrayMap.cpp array/arrayMap.cpp

View File

@ -1,4 +0,0 @@
<test>
<query>SELECT arrayFold(x, acc -> acc + 1, range(100000), toUInt64(0))</query> <!-- count -->
<query>SELECT arrayFold(x, acc -> acc + x, range(100000), toUInt64(0))</query> <!-- sum -->
</test>

View File

@ -1,8 +0,0 @@
23
3
101
269
[1,2,3,4]
[4,3,2,1]
([4,3,2,1],[1,2,3,4])
([1,3,5],[2,4,6])

View File

@ -1,8 +0,0 @@
SELECT arrayFold(x,acc -> acc + x * 2, [1,2,3,4], toInt64(3));
SELECT arrayFold(x,acc -> acc + x * 2, emptyArrayInt64(), toInt64(3));
SELECT arrayFold(x,y,acc -> acc + x * 2 + y * 3, [1,2,3,4], [5,6,7,8], toInt64(3));
SELECT arrayFold(x,y,z,acc -> acc + x * 2 + y * 3 + z * 4, [1,2,3,4], [5,6,7,8], [9,10,11,12], toInt64(3));
SELECT arrayFold(x,acc -> arrayPushBack(acc,x), [1,2,3,4], emptyArrayInt64());
SELECT arrayFold(x,acc -> arrayPushFront(acc,x), [1,2,3,4], emptyArrayInt64());
SELECT arrayFold(x,acc -> (arrayPushFront(acc.1,x), arrayPushBack(acc.2,x)), [1,2,3,4], (emptyArrayInt64(), emptyArrayInt64()));
SELECT arrayFold(x,acc -> x % 2 ? (arrayPushBack(acc.1,x), acc.2): (acc.1, arrayPushBack(acc.2,x)), [1,2,3,4,5,6], (emptyArrayInt64(), emptyArrayInt64()));

View File

@ -1,80 +0,0 @@
0
0
1
3
6
10
15
21
28
36
0
1
3
6
10
15
21
28
36
45
[]
[0]
[1,0]
[2,1,0]
[3,2,1,0]
[4,3,2,1,0]
[5,4,3,2,1,0]
[6,5,4,3,2,1,0]
[7,6,5,4,3,2,1,0]
[8,7,6,5,4,3,2,1,0]
[]
[0]
[1,0]
[1,0,2]
[3,1,0,2]
[3,1,0,2,4]
[5,3,1,0,2,4]
[5,3,1,0,2,4,6]
[7,5,3,1,0,2,4,6]
[7,5,3,1,0,2,4,6,8]
(0,0)
(0,0)
(1,-1)
(3,-3)
(6,-6)
(10,-10)
(15,-15)
(21,-21)
(28,-28)
(36,-36)
(0,0)
(0,0)
(1,-1)
(3,-3)
(6,-6)
(10,-10)
(15,-15)
(21,-21)
(28,-28)
(36,-36)
[(0,0)]
[(0,1),(0,0)]
[(1,2),(0,1),(0,0)]
[(2,3),(1,2),(0,1),(0,0)]
[(3,4),(2,3),(1,2),(0,1),(0,0)]
[(4,5),(3,4),(2,3),(1,2),(0,1),(0,0)]
[(5,6),(4,5),(3,4),(2,3),(1,2),(0,1),(0,0)]
[(6,7),(5,6),(4,5),(3,4),(2,3),(1,2),(0,1),(0,0)]
[(7,8),(6,7),(5,6),(4,5),(3,4),(2,3),(1,2),(0,1),(0,0)]
[(8,9),(7,8),(6,7),(5,6),(4,5),(3,4),(2,3),(1,2),(0,1),(0,0)]
[]
['0']
['0','1']
['0','1','2']
['0','1','2','3']
['0','1','2','3','4']
['0','1','2','3','4','5']
['0','1','2','3','4','5','6']
['0','1','2','3','4','5','6','7']
['0','1','2','3','4','5','6','7','8']

View File

@ -1,8 +0,0 @@
SELECT arrayFold(x,acc -> acc+x, range(number), toInt64(0)) FROM system.numbers LIMIT 10;
SELECT arrayFold(x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 10;
SELECT arrayFold(x,acc -> arrayPushFront(acc, x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 10;
SELECT arrayFold(x,acc -> x % 2 ? arrayPushFront(acc, x) : arrayPushBack(acc, x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 10;
SELECT arrayFold(x,acc -> (acc.1+x, acc.2-x), range(number), (toInt64(0), toInt64(0))) FROM system.numbers LIMIT 10;
SELECT arrayFold(x,acc -> (acc.1+x.1, acc.2-x.2), arrayZip(range(number), range(number)), (toInt64(0), toInt64(0))) FROM system.numbers LIMIT 10;
SELECT arrayFold(x,acc -> arrayPushFront(acc, (x, x+1)), range(number), [(toUInt64(0),toUInt64(0))]) FROM system.numbers LIMIT 10;
SELECT arrayFold(x, acc -> concat(acc, arrayMap(z -> toString(x), [number])) , range(number), CAST([] as Array(String))) FROM system.numbers LIMIT 10;

View File

@ -1,12 +0,0 @@
SELECT arrayFold([]); -- { serverError 42 }
SELECT arrayFold([1,2,3]); -- { serverError 42 }
SELECT arrayFold([1,2,3], [4,5,6]); -- { serverError 43 }
SELECT arrayFold(1234); -- { serverError 42 }
SELECT arrayFold(x, acc -> acc + x, 10, 20); -- { serverError 43 }
SELECT arrayFold(x, acc -> acc + x, 10, [20, 30, 40]); -- { serverError 43 }
SELECT arrayFold(x -> x * 2, [1,2,3,4], toInt64(3)); -- { serverError 43 }
SELECT arrayFold(x,acc -> acc+x, number, toInt64(0)) FROM system.numbers LIMIT 10; -- { serverError 43 }
SELECT arrayFold(x,y,acc -> acc + x * 2 + y * 3, [1,2,3,4], [5,6,7], toInt64(3)); -- { serverError 190 }
SELECT arrayFold(x,acc -> acc + x * 2 + y * 3, [1,2,3,4], [5,6,7,8], toInt64(3)); -- { serverError 47 }
SELECT arrayFold(x,acc -> acc + x * 2, [1,2,3,4], [5,6,7,8], toInt64(3)); -- { serverError 43 }
SELECT arrayFold(x,acc -> concat(acc,', ', x), [1, 2, 3, 4], '0') -- { serverError 44 }