diff --git a/dbms/src/Functions/CMakeLists.txt b/dbms/src/Functions/CMakeLists.txt index 3968e21f857..b63ce507845 100644 --- a/dbms/src/Functions/CMakeLists.txt +++ b/dbms/src/Functions/CMakeLists.txt @@ -65,6 +65,7 @@ generate_function_register(Array FunctionArrayHasAll FunctionArrayHasAny FunctionArrayIntersect + FunctionArrayResize ) diff --git a/dbms/src/Functions/FunctionsArray.cpp b/dbms/src/Functions/FunctionsArray.cpp index 228ac768d68..8240b95b9ac 100644 --- a/dbms/src/Functions/FunctionsArray.cpp +++ b/dbms/src/Functions/FunctionsArray.cpp @@ -3297,4 +3297,115 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable return ColumnArray::create(result_column, std::move(result_offsets_ptr)); } +/// Implementation of FunctionArrayResize. + +FunctionPtr FunctionArrayResize::create(const Context & context) +{ + return std::make_shared(context); +} + +String FunctionArrayResize::getName() const +{ + return name; +} + +DataTypePtr FunctionArrayResize::getReturnTypeImpl(const DataTypes & arguments) const +{ + size_t number_of_arguments = arguments.size(); + + if (number_of_arguments < 2 || number_of_arguments > 3) + throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " + + toString(number_of_arguments) + ", should be 2 or 3", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + if (arguments[0]->onlyNull()) + return arguments[0]; + + auto array_type = typeid_cast(arguments[0].get()); + if (!array_type) + throw Exception("First argument for function " + getName() + " must be an array but it has type " + + arguments[0]->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + if (checkDataType(array_type->getNestedType().get())) + throw Exception("Function " + getName() + " cannot resize " + array_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + if (!removeNullable(arguments[1])->isInteger() && !arguments[1]->onlyNull()) + throw Exception( + "Argument " + toString(1) + " for function " + getName() + " must be integer but it has type " + + arguments[1]->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + if (number_of_arguments) + return arguments[0]; + else + return std::make_shared(getLeastSupertype({array_type->getNestedType(), arguments[2]})); +} + +void FunctionArrayResize::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) +{ + const auto & return_type = block.getByPosition(result).type; + + if (return_type->onlyNull()) + { + block.getByPosition(result).column = return_type->createColumnConstWithDefaultValue(block.rows()); + return; + } + + auto result_column = return_type->createColumn(); + + auto array_column = block.getByPosition(arguments[0]).column; + auto size_column = block.getByPosition(arguments[1]).column; + + if (!block.getByPosition(arguments[0]).type->equals(*return_type)) + array_column = castColumn(block.getByPosition(arguments[0]), return_type, context); + + const DataTypePtr & return_nested_type = typeid_cast(*return_type).getNestedType(); + size_t size = array_column->size(); + + ColumnPtr appended_column; + if (arguments.size() == 3) + { + appended_column = block.getByPosition(arguments[2]).column; + if (!block.getByPosition(arguments[2]).type->equals(*return_nested_type)) + appended_column = castColumn(block.getByPosition(arguments[2]), return_nested_type, context); + } + else + appended_column = return_nested_type->createColumnConstWithDefaultValue(size); + + std::unique_ptr array_source; + std::unique_ptr value_source; + + bool is_const = false; + + if (auto const_array_column = typeid_cast(array_column.get())) + { + is_const = true; + array_column = const_array_column->getDataColumnPtr(); + } + + if (auto argument_column_array = typeid_cast(array_column.get())) + array_source = GatherUtils::createArraySource(*argument_column_array, is_const, size); + else + throw Exception{"First arguments for function " + getName() + " must be array.", ErrorCodes::LOGICAL_ERROR}; + + + bool is_appended_const = false; + if (auto const_appended_column = typeid_cast(appended_column.get())) + { + is_appended_const = true; + appended_column = const_appended_column->getDataColumnPtr(); + } + + value_source = GatherUtils::createValueSource(*appended_column, is_appended_const, size); + + auto sink = GatherUtils::createArraySink(typeid_cast(*result_column), size); + + if (size_column->isColumnConst()) + GatherUtils::resizeConstantSize(*array_source, *value_source, *sink, size_column->getInt(0)); + else + GatherUtils::resizeDynamicSize(*array_source, *value_source, *sink, *size_column); + + block.getByPosition(result).column = std::move(result_column); +} + + } diff --git a/dbms/src/Functions/FunctionsArray.h b/dbms/src/Functions/FunctionsArray.h index 3585bfe990d..fc429600b47 100644 --- a/dbms/src/Functions/FunctionsArray.h +++ b/dbms/src/Functions/FunctionsArray.h @@ -1641,6 +1641,30 @@ public: FunctionArrayHasAny(const Context & context) : FunctionArrayHasAllAny(context, false, name) {} }; + +class FunctionArrayResize : public IFunction +{ +public: + static constexpr auto name = "arrayResize"; + static FunctionPtr create(const Context & context); + FunctionArrayResize(const Context & context) : context(context) {}; + + String getName() const override; + + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override; + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result) override; + + bool useDefaultImplementationForConstants() const override { return true; } + bool useDefaultImplementationForNulls() const override { return false; } + +private: + const Context & context; +}; + struct NameHas { static constexpr auto name = "has"; }; struct NameIndexOf { static constexpr auto name = "indexOf"; }; struct NameCountEqual { static constexpr auto name = "countEqual"; }; diff --git a/dbms/src/Functions/GatherUtils/Algorithms.h b/dbms/src/Functions/GatherUtils/Algorithms.h index 425fc25f1cc..660fbaa2d2e 100644 --- a/dbms/src/Functions/GatherUtils/Algorithms.h +++ b/dbms/src/Functions/GatherUtils/Algorithms.h @@ -532,25 +532,27 @@ void resizeDynamicSize(ArraySource && array_source, ValueSource && value_source, if (size >= 0) { - if (array_size <= size) + auto length = static_cast(size); + if (array_size <= length) { writeSlice(array_source.getWhole(), sink); - for (auto i : ext::range(size, array_size)) + for (size_t i = array_size; i < length; ++i) writeSlice(value_source.getWhole(), sink); } else - writeSlice(array_source.getSliceFromLeft(0, size), sink); + writeSlice(array_source.getSliceFromLeft(0, length), sink); } else { - if (array_size <= -size) + auto length = static_cast(-size); + if (array_size <= length) { - for (auto i : ext::range(-size, array_size)) + for (size_t i = array_size; i < length; ++i) writeSlice(value_source.getWhole(), sink); writeSlice(array_source.getWhole(), sink); } else - writeSlice(array_source.getSliceFromLeft(0, -size), sink); + writeSlice(array_source.getSliceFromRight(length, length), sink); } } else @@ -567,30 +569,31 @@ void resizeConstantSize(ArraySource && array_source, ValueSource && value_source { while (!sink.isEnd()) { - size_t row_num = array_source.rowNum(); auto array_size = array_source.getElementSize(); if (size >= 0) { - if (array_size <= size) + auto length = static_cast(size); + if (array_size <= length) { writeSlice(array_source.getWhole(), sink); - for (auto i : ext::range(size, array_size)) + for (size_t i = array_size; i < length; ++i) writeSlice(value_source.getWhole(), sink); } else - writeSlice(array_source.getSliceFromLeft(0, size), sink); + writeSlice(array_source.getSliceFromLeft(0, length), sink); } else { - if (array_size <= -size) + auto length = static_cast(-size); + if (array_size <= length) { - for (auto i : ext::range(-size, array_size)) + for (size_t i = array_size; i < length; ++i) writeSlice(value_source.getWhole(), sink); writeSlice(array_source.getWhole(), sink); } else - writeSlice(array_source.getSliceFromLeft(0, -size), sink); + writeSlice(array_source.getSliceFromRight(length, length), sink); } value_source.next(); diff --git a/dbms/src/Functions/GatherUtils/GatherUtils.h b/dbms/src/Functions/GatherUtils/GatherUtils.h index c601c6b3bbd..da143af6fcf 100644 --- a/dbms/src/Functions/GatherUtils/GatherUtils.h +++ b/dbms/src/Functions/GatherUtils/GatherUtils.h @@ -54,6 +54,6 @@ void push(IArraySource & array_source, IValueSource & value_source, IArraySink & void resizeDynamicSize(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, const IColumn & size_column); -void resizeConstantSize(IArraySource & array_source, IValueSource & value_source, IArraySink && sink, ssize_t size); +void resizeConstantSize(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, ssize_t size); } diff --git a/dbms/src/Functions/GatherUtils/push.cpp b/dbms/src/Functions/GatherUtils/push.cpp index da1c2cf8a6f..2ba5b1a6322 100644 --- a/dbms/src/Functions/GatherUtils/push.cpp +++ b/dbms/src/Functions/GatherUtils/push.cpp @@ -20,6 +20,6 @@ struct ArrayPush : public ArrayAndValueSourceSelectorBySink void push(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, bool push_front) { - return ArrayPush::select(sink, array_source, value_source, push_front); + ArrayPush::select(sink, array_source, value_source, push_front); } } diff --git a/dbms/src/Functions/GatherUtils/resizeConstantSize.cpp b/dbms/src/Functions/GatherUtils/resizeConstantSize.cpp new file mode 100644 index 00000000000..a34437ed9ff --- /dev/null +++ b/dbms/src/Functions/GatherUtils/resizeConstantSize.cpp @@ -0,0 +1,22 @@ +#include +#include + +namespace DB::GatherUtils +{ + +struct ArrayResizeConstant : public ArrayAndValueSourceSelectorBySink +{ + template + static void selectArrayAndValueSourceBySink( + ArraySource && array_source, ValueSource && value_source, Sink && sink, ssize_t size) + { + resizeConstantSize(array_source, value_source, sink, size); + } +}; + + +void resizeConstantSize(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, ssize_t size) +{ + ArrayResizeConstant::select(sink, array_source, value_source, size); +} +} diff --git a/dbms/src/Functions/GatherUtils/resizeDynamicSize.cpp b/dbms/src/Functions/GatherUtils/resizeDynamicSize.cpp new file mode 100644 index 00000000000..b584e16dcaf --- /dev/null +++ b/dbms/src/Functions/GatherUtils/resizeDynamicSize.cpp @@ -0,0 +1,22 @@ +#include +#include + +namespace DB::GatherUtils +{ + +struct ArrayResizeDynamic : public ArrayAndValueSourceSelectorBySink +{ + template + static void selectArrayAndValueSourceBySink( + ArraySource && array_source, ValueSource && value_source, Sink && sink, const IColumn & size_column) + { + resizeDynamicSize(array_source, value_source, sink, size_column); + } +}; + + +void resizeDynamicSize(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, const IColumn & size_column) +{ + ArrayResizeDynamic::select(sink, array_source, value_source, size_column); +} +} diff --git a/dbms/tests/queries/0_stateless/00557_array_resize.reference b/dbms/tests/queries/0_stateless/00557_array_resize.reference new file mode 100644 index 00000000000..e981342920a --- /dev/null +++ b/dbms/tests/queries/0_stateless/00557_array_resize.reference @@ -0,0 +1,13 @@ +[1,2,3,0,0,0,0,0,0,0] +[0,0,0,0,0,0,0,1,2,3] +[1,NULL,3,NULL,NULL,NULL,NULL,NULL,NULL,NULL] +[NULL,NULL,NULL,NULL,NULL,NULL,NULL,1,NULL,3] +[1,2,3] +[4,5,6] +[1,2,3,42,42] +[42,42,1,2,3] +['a','b','c','',''] +[[1,2],[3,4],[],[]] +[[],[],[1,2],[3,4]] +[[1,2],[3,4],[5,6],[5,6]] +[[5,6],[5,6],[1,2],[3,4]] diff --git a/dbms/tests/queries/0_stateless/00557_array_resize.sql b/dbms/tests/queries/0_stateless/00557_array_resize.sql new file mode 100644 index 00000000000..11b27bc9e0d --- /dev/null +++ b/dbms/tests/queries/0_stateless/00557_array_resize.sql @@ -0,0 +1,14 @@ +select arrayResize([1, 2, 3], 10); +select arrayResize([1, 2, 3], -10); +select arrayResize([1, Null, 3], 10); +select arrayResize([1, Null, 3], -10); +select arrayResize([1, 2, 3, 4, 5, 6], 3); +select arrayResize([1, 2, 3, 4, 5, 6], -3); +select arrayResize([1, 2, 3], 5, 42); +select arrayResize([1, 2, 3], -5, 42); +select arrayResize(['a', 'b', 'c'], 5); +select arrayResize([[1, 2], [3, 4]], 4); +select arrayResize([[1, 2], [3, 4]], -4); +select arrayResize([[1, 2], [3, 4]], 4, [5, 6]); +select arrayResize([[1, 2], [3, 4]], -4, [5, 6]); +