diff --git a/dbms/src/Functions/array/arrayFill.cpp b/dbms/src/Functions/array/arrayFill.cpp new file mode 100644 index 00000000000..4c2dc5659b8 --- /dev/null +++ b/dbms/src/Functions/array/arrayFill.cpp @@ -0,0 +1,127 @@ +#include +#include +#include "FunctionArrayMapped.h" +#include + + +namespace DB +{ + +template +struct ArrayFillImpl +{ + static bool needBoolean() { return true; } + static bool needExpression() { return true; } + static bool needOneArray() { return false; } + + static DataTypePtr getReturnType(const DataTypePtr & /*expression_return*/, const DataTypePtr & array_element) + { + return std::make_shared(array_element); + } + + static ColumnPtr execute(const ColumnArray & array, ColumnPtr mapped) + { + const ColumnUInt8 * column_fill = typeid_cast(&*mapped); + + const IColumn & in_data = array.getData(); + const IColumn::Offsets & in_offsets = array.getOffsets(); + auto column_data = in_data.cloneEmpty(); + IColumn & out_data = *column_data.get(); + + if (column_fill) + { + const IColumn::Filter & fill = column_fill->getData(); + + size_t array_begin = 0; + size_t array_end = 0; + size_t begin = 0; + size_t end = 0; + + out_data.reserve(in_data.size()); + + for (size_t i = 0; i < in_offsets.size(); ++i) + { + array_end = in_offsets[i] - 1; + + for (; end <= array_end; ++end) + { + if (end == array_end || fill[end + 1] != fill[begin]) { + if (fill[begin]) + { + if constexpr (Reverse) + { + if (end == array_end) + out_data.insertManyFrom(in_data, array_end, end + 1 - begin); + else + out_data.insertManyFrom(in_data, end + 1, end + 1 - begin); + } + else + { + if (begin == array_begin) + out_data.insertManyFrom(in_data, array_begin, end + 1 - begin); + else + out_data.insertManyFrom(in_data, begin - 1, end + 1 - begin); + } + } + else + out_data.insertRangeFrom(in_data, begin, end + 1 - begin); + + begin = end + 1; + } + } + + array_begin = array_end + 1; + } + } + else + { + auto column_fill_const = checkAndGetColumnConst(&*mapped); + + if (!column_fill_const) + throw Exception("Unexpected type of cut column", ErrorCodes::ILLEGAL_COLUMN); + + if (column_fill_const->getValue()) + { + size_t array_begin = 0; + size_t array_end = 0; + + out_data.reserve(in_data.size()); + + for (size_t i = 0; i < in_offsets.size(); ++i) + { + array_end = in_offsets[i] - 1; + + if constexpr (Reverse) + out_data.insertManyFrom(in_data, array_end, array_end + 1 - array_begin); + else + out_data.insertManyFrom(in_data, array_begin, array_end + 1 - array_begin); + + array_begin = array_end + 1; + } + } + else + return ColumnArray::create( + array.getDataPtr(), + array.getOffsetsPtr() + ); + } + + return ColumnArray::create( + std::move(column_data), + array.getOffsetsPtr() + ); + } +}; + +struct NameArrayFill { static constexpr auto name = "arrayFill"; }; +struct NameArrayReverseFill { static constexpr auto name = "arrayReverseFill"; }; +using FunctionArrayFill = FunctionArrayMapped, NameArrayFill>; +using FunctionArrayReverseFill = FunctionArrayMapped, NameArrayReverseFill>; + +void registerFunctionArrayFill(FunctionFactory & factory) +{ + factory.registerFunction(); + factory.registerFunction(); +} + +} diff --git a/dbms/src/Functions/registerFunctionsHigherOrder.cpp b/dbms/src/Functions/registerFunctionsHigherOrder.cpp index 46e89850582..8511c0c412c 100644 --- a/dbms/src/Functions/registerFunctionsHigherOrder.cpp +++ b/dbms/src/Functions/registerFunctionsHigherOrder.cpp @@ -11,6 +11,7 @@ void registerFunctionArrayAll(FunctionFactory &); void registerFunctionArraySum(FunctionFactory &); void registerFunctionArrayFirst(FunctionFactory &); void registerFunctionArrayFirstIndex(FunctionFactory &); +void registerFunctionArrayFill(FunctionFactory &); void registerFunctionArraySplit(FunctionFactory &); void registerFunctionArraySort(FunctionFactory &); void registerFunctionArrayCumSum(FunctionFactory &); @@ -27,6 +28,7 @@ void registerFunctionsHigherOrder(FunctionFactory & factory) registerFunctionArraySum(factory); registerFunctionArrayFirst(factory); registerFunctionArrayFirstIndex(factory); + registerFunctionArrayFill(factory); registerFunctionArraySplit(factory); registerFunctionArraySort(factory); registerFunctionArrayCumSum(factory);