Some fixups

This commit is contained in:
Robert Schulze 2023-10-08 20:26:18 +00:00
parent 2848548c63
commit 07e0cc196d
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
5 changed files with 93 additions and 77 deletions

View File

@ -1123,6 +1123,32 @@ Result:
└─────────────────────────────┘
```
## arrayFold
Applies a lambda function to one or more equally-sized arrays and collects the result in an accumulator.
**Syntax**
``` sql
arrayFold(lambda_function, arr1, arr2, ..., accumulator)
```
**Example**
Query:
``` sql
SELECT arrayFold( x,acc -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res;
```
Result:
``` text
┌─arrayFold(lambda(tuple(x, acc), plus(acc, multiply(x, 2))), [1, 2, 3, 4], toInt64(3))─┐
│ 23 │
└───────────────────────────────────────────────────────────────────────────────────────┘
```
## arrayReverse(arr)
Returns an array of the same size as the original array containing the elements in reverse order.

View File

@ -14,8 +14,9 @@ namespace ErrorCodes
extern const int TYPE_MISMATCH;
}
/** arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, init_accum) - apply the expression to each element of the array (or set of parallel arrays).
*/
/**
* arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, accum_initial) - apply the expression to each element of the array (or set of arrays).
*/
class ArrayFold : public IFunction
{
public:
@ -29,21 +30,22 @@ public:
void getLambdaArgumentTypes(DataTypes & arguments) const override
{
if (arguments.size() < 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs lambda function, at least one array argument and one accumulator argument.", getName());
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator argument", getName());
DataTypes nested_types(arguments.size() - 1);
for (size_t i = 0; i < nested_types.size() - 1; ++i)
{
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array. Found {} instead.", toString(i + 2), getName(), arguments[i + 1]->getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array, found {} instead", i + 2, getName(), arguments[i + 1]->getName());
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
}
nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1];
const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
const auto * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments. Found {} instead.",
getName(), toString(nested_types.size()), arguments[0]->getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments, found {} instead.",
getName(), nested_types.size(), arguments[0]->getName());
arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
}
@ -51,45 +53,44 @@ public:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() < 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs at least 2 arguments; passed {}.", getName(), toString(arguments.size()));
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least 2 arguments, passed: {}.", getName(), arguments.size());
const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
if (!data_type_function)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function.", getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
auto const accumulator_type = arguments.back().type;
auto const lambda_type = data_type_function->getReturnType();
if (! accumulator_type->equals(*lambda_type))
throw Exception(ErrorCodes::TYPE_MISMATCH, "Return type of lambda function must be the same as the accumulator type. "
"Inferred type of lambda {}, inferred type of accumulator {}.", lambda_type->getName(), accumulator_type->getName());
auto accumulator_type = arguments.back().type;
auto lambda_type = data_type_function->getReturnType();
if (!accumulator_type->equals(*lambda_type))
throw Exception(ErrorCodes::TYPE_MISMATCH,
"Return type of lambda function must be the same as the accumulator type, inferred return type of lambda: {}, inferred type of accumulator: {}",
lambda_type->getName(), accumulator_type->getName());
return DataTypePtr(accumulator_type);
return accumulator_type;
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const auto & column_with_type_and_name = arguments[0];
const auto & lambda_with_type_and_name = arguments[0];
if (!column_with_type_and_name.column)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function.", getName());
if (!lambda_with_type_and_name.column)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
const auto * column_function = typeid_cast<const ColumnFunction *>(column_with_type_and_name.column.get());
if (!column_function)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function.", getName());
const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_with_type_and_name.column.get());
if (!lambda_function)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
ColumnPtr offsets_column;
ColumnPtr column_first_array_ptr;
const ColumnArray * column_first_array = nullptr;
ColumnsWithTypeAndName arrays;
arrays.reserve(arguments.size() - 1);
/// Valdate input types and get input array caolumns in convinient form
/// Validate input types and get input array columns in convenient form
for (size_t i = 1; i < arguments.size() - 1; ++i)
{
const auto & array_with_type_and_name = arguments[i];
ColumnPtr column_array_ptr = array_with_type_and_name.column;
const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
if (!column_array)
{
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
@ -98,12 +99,14 @@ public:
column_array_ptr = recursiveRemoveLowCardinality(column_const_array->convertToFullColumn());
column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
}
const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Expected array type, found {}", array_type_ptr->getName());
if (!offsets_column)
{
offsets_column = column_array->getOffsetsPtr();
}
else
{
/// The first condition is optimization: do not compare data if the pointers are equal.
@ -123,7 +126,7 @@ public:
ssize_t rows_count = input_rows_count;
ssize_t data_row_count = arrays[0].column->size();
auto array_count = arrays.size();
size_t array_count = arrays.size();
if (rows_count == 0)
return arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();
@ -147,7 +150,8 @@ public:
/// selector[i] is an index that i_th data element has in an array it corresponds to
for (ssize_t i = 0; i < data_row_count; ++i)
{
selector[i] = cur_ind++;
selector[i] = cur_ind;
cur_ind++;
if (cur_ind > max_array_size)
max_array_size = cur_ind;
while (cur_arr < rows_count && cur_ind >= offsets[cur_arr] - offsets[cur_arr - 1])
@ -170,10 +174,9 @@ public:
IColumn::Permutation inverse_permutation(rows_count);
size_t inverse_permutation_count = 0;
/**Current_column after each iteration contains value of accumulator after aplying values under indexes ind of arrays.
*At each iteration only rows of current_column with arrays that still has unaplied elements are kept.
*Discarded rows which contain finished calculations are added to result_data column and as we insert them we save their original row_number in inverse_permutation vector
*/
/// current_column after each iteration contains value of accumulator after applying values under indexes of arrays.
/// At each iteration only rows of current_column with arrays that still have unapplied elements are kept.
/// Discarded rows which contain finished calculations are added to result_data column and as we insert them we save their original row_number in inverse_permutation vector
for (size_t ind = 0; ind < max_array_size; ++ind)
{
IColumn::Selector prev_selector(prev_size);
@ -181,9 +184,7 @@ public:
for (ssize_t irow = 0; irow < rows_count; ++irow)
{
if (offsets[irow] - offsets[irow - 1] > ind)
{
prev_selector[prev_ind++] = 1;
}
else if (offsets[irow] - offsets[irow - 1] == ind)
{
inverse_permutation[inverse_permutation_count++] = irow;
@ -194,7 +195,7 @@ public:
result_data->insertRangeFrom(*(prev[0]), 0, prev[0]->size());
auto res_lambda = column_function->cloneResized(prev[1]->size());
auto res_lambda = lambda_function->cloneResized(prev[1]->size());
auto * res_lambda_ptr = typeid_cast<ColumnFunction *>(res_lambda.get());
for (size_t i = 0; i < array_count; i++)
@ -211,7 +212,7 @@ public:
inverse_permutation[inverse_permutation_count++] = irow;
/// We have result_data containing result for every row and inverse_permutation which contains indexes of rows in input it corresponds to.
/// Now we neead to invert inverse_permuation and apply it to result_data to get rows in right order.
/// Now we need to invert inverse_permutation and apply it to result_data to get rows in right order.
IColumn::Permutation perm(rows_count);
for (ssize_t i = 0; i < rows_count; i++)
perm[inverse_permutation[i]] = i;
@ -228,8 +229,8 @@ private:
REGISTER_FUNCTION(ArrayFold)
{
factory.registerFunction<ArrayFold>(FunctionDocumentation{.description=R"(
Function arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, init_accum) applies lambda function to a number of same sized array columns
and collects result in accumulator. Accumulator can be either constant or column.
)", .examples{{"sum", "SELECT arrayFold(x,acc -> acc + x, [1,2,3,4], toInt64(1));", "11"}}, .categories{"Array"}});
Function arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, accum_initial) applies lambda function to a number of equally-sized arrays
and collects the result in an accumulator.
)", .examples{{"sum", "SELECT arrayFold(x,acc -> acc+x, [1,2,3,4], toInt64(1));", "11"}}, .categories{"Array"}});
}
}

View File

@ -1,44 +1,25 @@
Negative tests
Const arrays
23
3
101
269
[1,2,3,4]
[4,3,2,1]
([4,3,2,1],[1,2,3,4])
([1,3,5],[2,4,6])
0
Non-const arrays
0
1
3
6
10
0
1
3
6
10
15
[]
[0]
[1,0]
[2,1,0]
[3,2,1,0]
[4,3,2,1,0]
[]
[0]
[1,0]
[1,0,2]
[3,1,0,2]
[3,1,0,2,4]
[(0,0)]
[(0,1),(0,0)]
[(1,2),(0,1),(0,0)]
[(2,3),(1,2),(0,1),(0,0)]
[(3,4),(2,3),(1,2),(0,1),(0,0)]
[(4,5),(3,4),(2,3),(1,2),(0,1),(0,0)]
[]
['0']
['0','1']
['0','1','2']
['0','1','2','3']
['0','1','2','3','4']

View File

@ -1,16 +1,23 @@
SELECT arrayFold(x,acc -> acc + x * 2, [1,2,3,4], toInt64(3));
SELECT arrayFold(x,acc -> acc + x * 2, emptyArrayInt64(), toInt64(3));
SELECT arrayFold(x,y,acc -> acc + x * 2 + y * 3, [1,2,3,4], [5,6,7,8], toInt64(3));
SELECT arrayFold(x,y,z,acc -> acc + x * 2 + y * 3 + z * 4, [1,2,3,4], [5,6,7,8], [9,10,11,12], toInt64(3));
SELECT arrayFold(x,acc -> arrayPushBack(acc,x), [1,2,3,4], emptyArrayInt64());
SELECT arrayFold(x,acc -> arrayPushFront(acc,x), [1,2,3,4], emptyArrayInt64());
SELECT arrayFold(x,acc -> (arrayPushFront(acc.1,x), arrayPushBack(acc.2,x)), [1,2,3,4], (emptyArrayInt64(), emptyArrayInt64()));
SELECT arrayFold(x,acc -> x % 2 ? (arrayPushBack(acc.1,x), acc.2): (acc.1, arrayPushBack(acc.2,x)), [1,2,3,4,5,6], (emptyArrayInt64(), emptyArrayInt64()));
SELECT 'Negative tests';
SELECT arrayFold(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayFold(1); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT arrayFold(1, toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,acc -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH }
SELECT arrayFold( x,acc -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,y,acc -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,acc -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayFold( x,y,acc -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH }
SELECT arrayFold(x,acc -> acc+x, range(number), toInt64(0)) FROM system.numbers LIMIT 6;
SELECT arrayFold(x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 6;
SELECT arrayFold(x,acc -> arrayPushFront(acc, x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 6;
SELECT arrayFold(x,acc -> x % 2 ? arrayPushFront(acc, x) : arrayPushBack(acc, x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 6;
SELECT arrayFold(x,acc -> arrayPushFront(acc, (x, x+1)), range(number), [(toUInt64(0),toUInt64(0))]) FROM system.numbers LIMIT 6;
SELECT arrayFold(x, acc -> concat(acc, arrayMap(z -> toString(x), [number])) , range(number), CAST([] as Array(String))) FROM system.numbers LIMIT 6;
SELECT 'Const arrays';
SELECT arrayFold( x,acc -> acc+x*2, [1, 2, 3, 4], toInt64(3));
SELECT arrayFold( x,acc -> acc+x*2, emptyArrayInt64(), toInt64(3));
SELECT arrayFold( x,y,acc -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3));
SELECT arrayFold( x,acc -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( x,acc -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64());
SELECT arrayFold( x,acc -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64()));
SELECT arrayFold( x,acc -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64()));
SELECT 'Non-const arrays';
SELECT arrayFold( x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 5;
SELECT arrayFold( x,acc -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
SELECT arrayFold( x,acc -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;

View File

@ -1046,6 +1046,7 @@ arrayFilter
arrayFirst
arrayFirstIndex
arrayFlatten
arrayFold
arrayIntersect
arrayJaccardIndex
arrayJoin