diff --git a/dbms/src/Functions/array/arraySplit.cpp b/dbms/src/Functions/array/arraySplit.cpp
new file mode 100644
index 00000000000..dcb0c73e8a4
--- /dev/null
+++ b/dbms/src/Functions/array/arraySplit.cpp
@@ -0,0 +1,135 @@
+#include <DataTypes/DataTypesNumber.h>
+#include <Columns/ColumnsNumber.h>
+#include "FunctionArrayMapped.h"
+#include <Functions/FunctionFactory.h>
+
+
+namespace DB
+{
+
+/** arraySplit(func, arr1, ...) - cuts arr1 into sub-arrays. A new sub-array starts at every element
+  * (except the first one) for which func returns non-zero.
+  * arrayReverseSplit(func, arr1, ...) - the same, but a sub-array ends at every element
+  * (except the last one) for which func returns non-zero.
+  * An empty array is split into zero sub-arrays.
+  */
+template <bool Reverse>
+struct ArraySplitImpl
+{
+    static bool needBoolean() { return true; }
+    static bool needExpression() { return true; }
+    static bool needOneArray() { return false; }
+
+    /// The result is an array of arrays of the source element type.
+    static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element)
+    {
+        return std::make_shared<DataTypeArray>(
+            std::make_shared<DataTypeArray>(array_element)
+        );
+    }
+
+    /** array - the source array column; mapped - the value of func for every element of array.
+      * The data of the source column is reused as is; only two levels of offsets are built:
+      * out_offsets_2 - ends of the sub-arrays, out_offsets_1 - for every row, the number of sub-arrays so far.
+      */
+    static ColumnPtr execute(const ColumnArray & array, ColumnPtr mapped)
+    {
+        const ColumnUInt8 * column_cut = typeid_cast<const ColumnUInt8 *>(&*mapped);
+
+        const IColumn::Offsets & in_offsets = array.getOffsets();
+        auto column_offsets_2 = ColumnArray::ColumnOffsets::create();
+        auto column_offsets_1 = ColumnArray::ColumnOffsets::create();
+        IColumn::Offsets & out_offsets_2 = column_offsets_2->getData();
+        IColumn::Offsets & out_offsets_1 = column_offsets_1->getData();
+
+        if (column_cut)
+        {
+            const IColumn::Filter & cut = column_cut->getData();
+
+            size_t pos = 0;
+
+            out_offsets_2.reserve(in_offsets.size()); // the actual size would be equal or larger
+            out_offsets_1.reserve(in_offsets.size());
+
+            for (size_t i = 0; i < in_offsets.size(); ++i)
+            {
+                /// An empty array yields no sub-arrays and must be skipped: otherwise pos would be
+                /// advanced past its end, and the unsigned bound in_offsets[i] - Reverse could wrap around.
+                if (pos < in_offsets[i])
+                {
+                    /// The flag of the first (forward) or the last (reverse) element is not examined.
+                    pos += !Reverse;
+                    for (; pos < in_offsets[i] - Reverse; ++pos)
+                    {
+                        if (cut[pos])
+                            out_offsets_2.push_back(pos + Reverse);
+                    }
+                    pos += Reverse;
+
+                    out_offsets_2.push_back(pos);
+                }
+
+                out_offsets_1.push_back(out_offsets_2.size());
+            }
+        }
+        else
+        {
+            auto column_cut_const = checkAndGetColumnConst<ColumnUInt8>(&*mapped);
+
+            if (!column_cut_const)
+                throw Exception("Unexpected type of cut column", ErrorCodes::ILLEGAL_COLUMN);
+
+            if (column_cut_const->getValue<UInt8>())
+            {
+                /// Every element becomes a sub-array of its own.
+                out_offsets_2.reserve(in_offsets.back());
+                out_offsets_1.reserve(in_offsets.size());
+
+                for (size_t i = 0; i < in_offsets.back(); ++i)
+                    out_offsets_2.push_back(i + 1);
+                for (size_t i = 0; i < in_offsets.size(); ++i)
+                    out_offsets_1.push_back(in_offsets[i]);
+            }
+            else
+            {
+                size_t pos = 0;
+
+                out_offsets_2.reserve(in_offsets.size());
+                out_offsets_1.reserve(in_offsets.size());
+
+                for (size_t i = 0; i < in_offsets.size(); ++i)
+                {
+                    /// A non-empty array becomes a single sub-array; an empty one yields no sub-arrays.
+                    if (pos < in_offsets[i])
+                    {
+                        pos = in_offsets[i];
+                        out_offsets_2.push_back(pos);
+                    }
+
+                    out_offsets_1.push_back(out_offsets_2.size());
+                }
+            }
+        }
+
+        return ColumnArray::create(
+            ColumnArray::create(
+                array.getDataPtr(),
+                std::move(column_offsets_2)
+            ),
+            std::move(column_offsets_1)
+        );
+    }
+};
+
+struct NameArraySplit { static constexpr auto name = "arraySplit"; };
+struct NameArrayReverseSplit { static constexpr auto name = "arrayReverseSplit"; };
+using FunctionArraySplit = FunctionArrayMapped<ArraySplitImpl<false>, NameArraySplit>;
+using FunctionArrayReverseSplit = FunctionArrayMapped<ArraySplitImpl<true>, NameArrayReverseSplit>;
+
+void registerFunctionArraySplit(FunctionFactory & factory)
+{
+    factory.registerFunction<FunctionArraySplit>();
+    factory.registerFunction<FunctionArrayReverseSplit>();
+}
+
+}
diff --git a/dbms/src/Functions/registerFunctionsHigherOrder.cpp b/dbms/src/Functions/registerFunctionsHigherOrder.cpp
index e0948ebc913..2e8b678240b 100644
--- a/dbms/src/Functions/registerFunctionsHigherOrder.cpp
+++ b/dbms/src/Functions/registerFunctionsHigherOrder.cpp
@@ -11,6 +11,7 @@ void registerFunctionArrayAll(FunctionFactory &);
 void registerFunctionArraySum(FunctionFactory &);
 void registerFunctionArrayFirst(FunctionFactory &);
 void registerFunctionArrayFirstIndex(FunctionFactory &);
+void registerFunctionArraySplit(FunctionFactory &);
 void registerFunctionsArraySort(FunctionFactory &);
 void registerFunctionArrayReverseSort(FunctionFactory &);
 void registerFunctionArrayCumSum(FunctionFactory &);
@@ -27,6 +28,7 @@ void registerFunctionsHigherOrder(FunctionFactory & factory)
     registerFunctionArraySum(factory);
     registerFunctionArrayFirst(factory);
     registerFunctionArrayFirstIndex(factory);
+    registerFunctionArraySplit(factory);
     registerFunctionsArraySort(factory);
     registerFunctionArrayCumSum(factory);
     registerFunctionArrayCumSumNonNegative(factory);
diff --git a/dbms/tests/queries/0_stateless/01015_array_split.reference b/dbms/tests/queries/0_stateless/01015_array_split.reference
new file mode 100644
index 00000000000..ea9d36a95b2
--- /dev/null
+++ b/dbms/tests/queries/0_stateless/01015_array_split.reference
@@ -0,0 +1,18 @@
+[[1,2,3],[4,5]]
+[[1],[2,3,4],[5]]
+[[1,2,3,4,5]]
+[[1,2,3,4,5]]
+[[1],[2],[3],[4],[5]]
+[[1],[2],[3],[4],[5]]
+[[1,2],[3,4],[5]]
+[[1],[2,3],[4,5]]
+[]
+[]
+[]
+[]
+[]
+[]
+[[1]]
+[[1]]
+[[2]]
+[[2]]
diff --git a/dbms/tests/queries/0_stateless/01015_array_split.sql b/dbms/tests/queries/0_stateless/01015_array_split.sql
new file mode 100644
index 00000000000..64d456ed724
--- /dev/null
+++ b/dbms/tests/queries/0_stateless/01015_array_split.sql
@@ -0,0 +1,21 @@
+SELECT arraySplit((x, y) -> y, [1, 2, 3, 4, 5], [1, 0, 0, 1, 0]);
+SELECT arrayReverseSplit((x, y) -> y, [1, 2, 3, 4, 5], [1, 0, 0, 1, 0]);
+
+SELECT arraySplit(x -> 0, [1, 2, 3, 4, 5]);
+SELECT arrayReverseSplit(x -> 0, [1, 2, 3, 4, 5]);
+SELECT arraySplit(x -> 1, [1, 2, 3, 4, 5]);
+SELECT arrayReverseSplit(x -> 1, [1, 2, 3, 4, 5]);
+SELECT arraySplit(x -> x % 2 = 1, [1, 2, 3, 4, 5]);
+SELECT arrayReverseSplit(x -> x % 2 = 1, [1, 2, 3, 4, 5]);
+
+SELECT arraySplit(x -> 0, []);
+SELECT arrayReverseSplit(x -> 0, []);
+SELECT arraySplit(x -> 1, []);
+SELECT arrayReverseSplit(x -> 1, []);
+SELECT arraySplit(x -> x % 2 = 1, emptyArrayUInt8());
+SELECT arrayReverseSplit(x -> x % 2 = 1, emptyArrayUInt8());
+
+SELECT arraySplit(x -> x % 2 = 1, [1]);
+SELECT arrayReverseSplit(x -> x % 2 = 1, [1]);
+SELECT arraySplit(x -> x % 2 = 1, [2]);
+SELECT arrayReverseSplit(x -> x % 2 = 1, [2]);