mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 01:25:21 +00:00
Some fixups
This commit is contained in:
parent
2848548c63
commit
07e0cc196d
@ -1123,6 +1123,32 @@ Result:
|
||||
└─────────────────────────────┘
|
||||
```
|
||||
|
||||
## arrayFold
|
||||
|
||||
Applies a lambda function to one or more equally-sized arrays and collects the result in an accumulator.
|
||||
|
||||
**Syntax**
|
||||
|
||||
``` sql
|
||||
arrayFold(lambda_function, arr1, arr2, ..., accumulator)
|
||||
```
|
||||
|
||||
**Example**
|
||||
|
||||
Query:
|
||||
|
||||
``` sql
|
||||
SELECT arrayFold( x,acc -> acc + x*2, [1, 2, 3, 4], toInt64(3)) AS res;
|
||||
```
|
||||
|
||||
Result:
|
||||
|
||||
``` text
|
||||
┌─arrayFold(lambda(tuple(x, acc), plus(acc, multiply(x, 2))), [1, 2, 3, 4], toInt64(3))─┐
|
||||
│ 23 │
|
||||
└───────────────────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## arrayReverse(arr)
|
||||
|
||||
Returns an array of the same size as the original array containing the elements in reverse order.
|
||||
|
@ -14,8 +14,9 @@ namespace ErrorCodes
|
||||
extern const int TYPE_MISMATCH;
|
||||
}
|
||||
|
||||
/** arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, init_accum) - apply the expression to each element of the array (or set of parallel arrays).
|
||||
*/
|
||||
/**
|
||||
* arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, accum_initial) - apply the expression to each element of the array (or set of arrays).
|
||||
*/
|
||||
class ArrayFold : public IFunction
|
||||
{
|
||||
public:
|
||||
@ -29,21 +30,22 @@ public:
|
||||
void getLambdaArgumentTypes(DataTypes & arguments) const override
|
||||
{
|
||||
if (arguments.size() < 3)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs lambda function, at least one array argument and one accumulator argument.", getName());
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires as arguments a lambda function, at least one array and an accumulator argument", getName());
|
||||
|
||||
DataTypes nested_types(arguments.size() - 1);
|
||||
for (size_t i = 0; i < nested_types.size() - 1; ++i)
|
||||
{
|
||||
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
|
||||
const auto * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
|
||||
if (!array_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array. Found {} instead.", toString(i + 2), getName(), arguments[i + 1]->getName());
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array, found {} instead", i + 2, getName(), arguments[i + 1]->getName());
|
||||
nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType());
|
||||
}
|
||||
nested_types[nested_types.size() - 1] = arguments[arguments.size() - 1];
|
||||
|
||||
const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
|
||||
const auto * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
|
||||
if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments. Found {} instead.",
|
||||
getName(), toString(nested_types.size()), arguments[0]->getName());
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments, found {} instead.",
|
||||
getName(), nested_types.size(), arguments[0]->getName());
|
||||
|
||||
arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
|
||||
}
|
||||
@ -51,45 +53,44 @@ public:
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
if (arguments.size() < 2)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs at least 2 arguments; passed {}.", getName(), toString(arguments.size()));
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least 2 arguments, passed: {}.", getName(), arguments.size());
|
||||
|
||||
const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
|
||||
if (!data_type_function)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function.", getName());
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
|
||||
|
||||
auto const accumulator_type = arguments.back().type;
|
||||
auto const lambda_type = data_type_function->getReturnType();
|
||||
if (! accumulator_type->equals(*lambda_type))
|
||||
throw Exception(ErrorCodes::TYPE_MISMATCH, "Return type of lambda function must be the same as the accumulator type. "
|
||||
"Inferred type of lambda {}, inferred type of accumulator {}.", lambda_type->getName(), accumulator_type->getName());
|
||||
auto accumulator_type = arguments.back().type;
|
||||
auto lambda_type = data_type_function->getReturnType();
|
||||
if (!accumulator_type->equals(*lambda_type))
|
||||
throw Exception(ErrorCodes::TYPE_MISMATCH,
|
||||
"Return type of lambda function must be the same as the accumulator type, inferred return type of lambda: {}, inferred type of accumulator: {}",
|
||||
lambda_type->getName(), accumulator_type->getName());
|
||||
|
||||
return DataTypePtr(accumulator_type);
|
||||
return accumulator_type;
|
||||
}
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
|
||||
{
|
||||
const auto & column_with_type_and_name = arguments[0];
|
||||
const auto & lambda_with_type_and_name = arguments[0];
|
||||
|
||||
if (!column_with_type_and_name.column)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function.", getName());
|
||||
if (!lambda_with_type_and_name.column)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
|
||||
|
||||
const auto * column_function = typeid_cast<const ColumnFunction *>(column_with_type_and_name.column.get());
|
||||
|
||||
if (!column_function)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function.", getName());
|
||||
const auto * lambda_function = typeid_cast<const ColumnFunction *>(lambda_with_type_and_name.column.get());
|
||||
if (!lambda_function)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
|
||||
|
||||
ColumnPtr offsets_column;
|
||||
ColumnPtr column_first_array_ptr;
|
||||
const ColumnArray * column_first_array = nullptr;
|
||||
ColumnsWithTypeAndName arrays;
|
||||
arrays.reserve(arguments.size() - 1);
|
||||
/// Valdate input types and get input array caolumns in convinient form
|
||||
/// Validate input types and get input array columns in convenient form
|
||||
for (size_t i = 1; i < arguments.size() - 1; ++i)
|
||||
{
|
||||
const auto & array_with_type_and_name = arguments[i];
|
||||
ColumnPtr column_array_ptr = array_with_type_and_name.column;
|
||||
const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
|
||||
const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
|
||||
const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
|
||||
if (!column_array)
|
||||
{
|
||||
const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
|
||||
@ -98,12 +99,14 @@ public:
|
||||
column_array_ptr = recursiveRemoveLowCardinality(column_const_array->convertToFullColumn());
|
||||
column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
|
||||
}
|
||||
|
||||
const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
|
||||
const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
|
||||
if (!array_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Expected array type, found {}", array_type_ptr->getName());
|
||||
|
||||
if (!offsets_column)
|
||||
{
|
||||
offsets_column = column_array->getOffsetsPtr();
|
||||
}
|
||||
else
|
||||
{
|
||||
/// The first condition is optimization: do not compare data if the pointers are equal.
|
||||
@ -123,7 +126,7 @@ public:
|
||||
|
||||
ssize_t rows_count = input_rows_count;
|
||||
ssize_t data_row_count = arrays[0].column->size();
|
||||
auto array_count = arrays.size();
|
||||
size_t array_count = arrays.size();
|
||||
|
||||
if (rows_count == 0)
|
||||
return arguments.back().column->convertToFullColumnIfConst()->cloneEmpty();
|
||||
@ -147,7 +150,8 @@ public:
|
||||
/// selector[i] is an index that i_th data element has in an array it corresponds to
|
||||
for (ssize_t i = 0; i < data_row_count; ++i)
|
||||
{
|
||||
selector[i] = cur_ind++;
|
||||
selector[i] = cur_ind;
|
||||
cur_ind++;
|
||||
if (cur_ind > max_array_size)
|
||||
max_array_size = cur_ind;
|
||||
while (cur_arr < rows_count && cur_ind >= offsets[cur_arr] - offsets[cur_arr - 1])
|
||||
@ -170,10 +174,9 @@ public:
|
||||
IColumn::Permutation inverse_permutation(rows_count);
|
||||
size_t inverse_permutation_count = 0;
|
||||
|
||||
/**Current_column after each iteration contains value of accumulator after aplying values under indexes ind of arrays.
|
||||
*At each iteration only rows of current_column with arrays that still has unaplied elements are kept.
|
||||
*Discarded rows which contain finished calculations are added to result_data column and as we insert them we save their original row_number in inverse_permutation vector
|
||||
*/
|
||||
/// current_column after each iteration contains value of accumulator after applying values under indexes of arrays.
|
||||
/// At each iteration only rows of current_column with arrays that still have unapplied elements are kept.
|
||||
/// Discarded rows which contain finished calculations are added to result_data column and as we insert them we save their original row_number in inverse_permutation vector
|
||||
for (size_t ind = 0; ind < max_array_size; ++ind)
|
||||
{
|
||||
IColumn::Selector prev_selector(prev_size);
|
||||
@ -181,9 +184,7 @@ public:
|
||||
for (ssize_t irow = 0; irow < rows_count; ++irow)
|
||||
{
|
||||
if (offsets[irow] - offsets[irow - 1] > ind)
|
||||
{
|
||||
prev_selector[prev_ind++] = 1;
|
||||
}
|
||||
else if (offsets[irow] - offsets[irow - 1] == ind)
|
||||
{
|
||||
inverse_permutation[inverse_permutation_count++] = irow;
|
||||
@ -194,7 +195,7 @@ public:
|
||||
|
||||
result_data->insertRangeFrom(*(prev[0]), 0, prev[0]->size());
|
||||
|
||||
auto res_lambda = column_function->cloneResized(prev[1]->size());
|
||||
auto res_lambda = lambda_function->cloneResized(prev[1]->size());
|
||||
auto * res_lambda_ptr = typeid_cast<ColumnFunction *>(res_lambda.get());
|
||||
|
||||
for (size_t i = 0; i < array_count; i++)
|
||||
@ -211,7 +212,7 @@ public:
|
||||
inverse_permutation[inverse_permutation_count++] = irow;
|
||||
|
||||
/// We have result_data containing result for every row and inverse_permutation which contains indexes of rows in input it corresponds to.
|
||||
/// Now we neead to invert inverse_permuation and apply it to result_data to get rows in right order.
|
||||
/// Now we need to invert inverse_permutation and apply it to result_data to get rows in the right order.
|
||||
IColumn::Permutation perm(rows_count);
|
||||
for (ssize_t i = 0; i < rows_count; i++)
|
||||
perm[inverse_permutation[i]] = i;
|
||||
@ -228,8 +229,8 @@ private:
|
||||
REGISTER_FUNCTION(ArrayFold)
|
||||
{
|
||||
factory.registerFunction<ArrayFold>(FunctionDocumentation{.description=R"(
|
||||
Function arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, init_accum) applies lambda function to a number of same sized array columns
|
||||
and collects result in accumulator. Accumulator can be either constant or column.
|
||||
)", .examples{{"sum", "SELECT arrayFold(x,acc -> acc + x, [1,2,3,4], toInt64(1));", "11"}}, .categories{"Array"}});
|
||||
Function arrayFold(x1,...,xn,accum -> expression, array1,...,arrayn, accum_initial) applies lambda function to a number of equally-sized arrays
|
||||
and collects the result in an accumulator.
|
||||
)", .examples{{"sum", "SELECT arrayFold(x,acc -> acc+x, [1,2,3,4], toInt64(1));", "11"}}, .categories{"Array"}});
|
||||
}
|
||||
}
|
||||
|
@ -1,44 +1,25 @@
|
||||
Negative tests
|
||||
Const arrays
|
||||
23
|
||||
3
|
||||
101
|
||||
269
|
||||
[1,2,3,4]
|
||||
[4,3,2,1]
|
||||
([4,3,2,1],[1,2,3,4])
|
||||
([1,3,5],[2,4,6])
|
||||
0
|
||||
Non-const arrays
|
||||
0
|
||||
1
|
||||
3
|
||||
6
|
||||
10
|
||||
0
|
||||
1
|
||||
3
|
||||
6
|
||||
10
|
||||
15
|
||||
[]
|
||||
[0]
|
||||
[1,0]
|
||||
[2,1,0]
|
||||
[3,2,1,0]
|
||||
[4,3,2,1,0]
|
||||
[]
|
||||
[0]
|
||||
[1,0]
|
||||
[1,0,2]
|
||||
[3,1,0,2]
|
||||
[3,1,0,2,4]
|
||||
[(0,0)]
|
||||
[(0,1),(0,0)]
|
||||
[(1,2),(0,1),(0,0)]
|
||||
[(2,3),(1,2),(0,1),(0,0)]
|
||||
[(3,4),(2,3),(1,2),(0,1),(0,0)]
|
||||
[(4,5),(3,4),(2,3),(1,2),(0,1),(0,0)]
|
||||
[]
|
||||
['0']
|
||||
['0','1']
|
||||
['0','1','2']
|
||||
['0','1','2','3']
|
||||
['0','1','2','3','4']
|
||||
|
@ -1,16 +1,23 @@
|
||||
SELECT arrayFold(x,acc -> acc + x * 2, [1,2,3,4], toInt64(3));
|
||||
SELECT arrayFold(x,acc -> acc + x * 2, emptyArrayInt64(), toInt64(3));
|
||||
SELECT arrayFold(x,y,acc -> acc + x * 2 + y * 3, [1,2,3,4], [5,6,7,8], toInt64(3));
|
||||
SELECT arrayFold(x,y,z,acc -> acc + x * 2 + y * 3 + z * 4, [1,2,3,4], [5,6,7,8], [9,10,11,12], toInt64(3));
|
||||
SELECT arrayFold(x,acc -> arrayPushBack(acc,x), [1,2,3,4], emptyArrayInt64());
|
||||
SELECT arrayFold(x,acc -> arrayPushFront(acc,x), [1,2,3,4], emptyArrayInt64());
|
||||
SELECT arrayFold(x,acc -> (arrayPushFront(acc.1,x), arrayPushBack(acc.2,x)), [1,2,3,4], (emptyArrayInt64(), emptyArrayInt64()));
|
||||
SELECT arrayFold(x,acc -> x % 2 ? (arrayPushBack(acc.1,x), acc.2): (acc.1, arrayPushBack(acc.2,x)), [1,2,3,4,5,6], (emptyArrayInt64(), emptyArrayInt64()));
|
||||
SELECT 'Negative tests';
|
||||
SELECT arrayFold(); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
SELECT arrayFold(1); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
|
||||
SELECT arrayFold(1, toUInt64(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( x,acc -> x, emptyArrayString(), toInt8(0)); -- { serverError TYPE_MISMATCH }
|
||||
SELECT arrayFold( x,acc -> x, 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( x,y,acc -> x, [0, 1], 'not an array', toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( x,acc -> x, [0, 1], [2, 3], toUInt8(0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
|
||||
SELECT arrayFold( x,y,acc -> x, [0, 1], [2, 3, 4], toUInt8(0)); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH }
|
||||
|
||||
SELECT arrayFold(x,acc -> acc+x, range(number), toInt64(0)) FROM system.numbers LIMIT 6;
|
||||
SELECT arrayFold(x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 6;
|
||||
SELECT arrayFold(x,acc -> arrayPushFront(acc, x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 6;
|
||||
SELECT arrayFold(x,acc -> x % 2 ? arrayPushFront(acc, x) : arrayPushBack(acc, x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 6;
|
||||
SELECT arrayFold(x,acc -> arrayPushFront(acc, (x, x+1)), range(number), [(toUInt64(0),toUInt64(0))]) FROM system.numbers LIMIT 6;
|
||||
SELECT arrayFold(x, acc -> concat(acc, arrayMap(z -> toString(x), [number])) , range(number), CAST([] as Array(String))) FROM system.numbers LIMIT 6;
|
||||
SELECT 'Const arrays';
|
||||
SELECT arrayFold( x,acc -> acc+x*2, [1, 2, 3, 4], toInt64(3));
|
||||
SELECT arrayFold( x,acc -> acc+x*2, emptyArrayInt64(), toInt64(3));
|
||||
SELECT arrayFold( x,y,acc -> acc+x*2+y*3, [1, 2, 3, 4], [5, 6, 7, 8], toInt64(3));
|
||||
SELECT arrayFold( x,acc -> arrayPushBack(acc, x), [1, 2, 3, 4], emptyArrayInt64());
|
||||
SELECT arrayFold( x,acc -> arrayPushFront(acc, x), [1, 2, 3, 4], emptyArrayInt64());
|
||||
SELECT arrayFold( x,acc -> (arrayPushFront(acc.1, x),arrayPushBack(acc.2, x)), [1, 2, 3, 4], (emptyArrayInt64(), emptyArrayInt64()));
|
||||
SELECT arrayFold( x,acc -> x%2 ? (arrayPushBack(acc.1, x), acc.2): (acc.1, arrayPushBack(acc.2, x)), [1, 2, 3, 4, 5, 6], (emptyArrayInt64(), emptyArrayInt64()));
|
||||
|
||||
SELECT 'Non-const arrays';
|
||||
SELECT arrayFold( x,acc -> acc+x, range(number), number) FROM system.numbers LIMIT 5;
|
||||
SELECT arrayFold( x,acc -> arrayPushFront(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
|
||||
SELECT arrayFold( x,acc -> x%2 ? arrayPushFront(acc,x) : arrayPushBack(acc,x), range(number), emptyArrayUInt64()) FROM system.numbers LIMIT 5;
|
||||
|
@ -1046,6 +1046,7 @@ arrayFilter
|
||||
arrayFirst
|
||||
arrayFirstIndex
|
||||
arrayFlatten
|
||||
arrayFold
|
||||
arrayIntersect
|
||||
arrayJaccardIndex
|
||||
arrayJoin
|
||||
|
Loading…
Reference in New Issue
Block a user