From a8b78abc543863ee404395c9b24145a28c6aee57 Mon Sep 17 00:00:00 2001 From: Joanna Hulboj Date: Fri, 20 Jan 2023 20:39:33 +0000 Subject: [PATCH] Added arrayPartialShuffle function --- src/Functions/array/arrayShuffle.cpp | 125 ++++++++++++++---- .../0_stateless/02523_array_shuffle.reference | 42 ++++++ .../0_stateless/02523_array_shuffle.sql | 42 ++++++ 3 files changed, 186 insertions(+), 23 deletions(-) diff --git a/src/Functions/array/arrayShuffle.cpp b/src/Functions/array/arrayShuffle.cpp index 3941eb7271d..47608a8524e 100644 --- a/src/Functions/array/arrayShuffle.cpp +++ b/src/Functions/array/arrayShuffle.cpp @@ -6,11 +6,13 @@ #include #include #include -#include #include #include +#include #include +#include + #include #include @@ -28,52 +30,83 @@ namespace ErrorCodes * arrayShuffle(arr) * arrayShuffle(arr, seed) */ -class FunctionArrayShuffle : public IFunction +struct FunctionArrayShuffleTraits +{ + static constexpr auto name = "arrayShuffle"; + static constexpr auto has_limit = false; // Permute the whole array + static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1}; } + static constexpr auto max_num_params = 2; // array[, seed] + static constexpr auto seed_param_idx = 1; +}; + +/** Partial shuffle array elements + * arrayPartialShuffle(arr) + * arrayPartialShuffle(arr, limit) + * arrayPartialShuffle(arr, limit, seed) + */ +struct FunctionArrayPartialShuffleTraits +{ + static constexpr auto name = "arrayPartialShuffle"; + static constexpr auto has_limit = true; + static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1, 2}; } + static constexpr auto max_num_params = 3; // array[, limit[, seed]] + static constexpr auto seed_param_idx = 2; +}; + +template +class FunctionArrayShuffleImpl : public IFunction { public: - static constexpr auto name = "arrayShuffle"; - - static FunctionPtr create(ContextPtr) { return std::make_shared(); } + static constexpr auto name = Traits::name; String getName() const override { return name; } bool 
isVariadic() const override { return true; } size_t getNumberOfArguments() const override { return 0; } - ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return Traits::getArgumentsThatAreAlwaysConstant(); } + bool useDefaultImplementationForConstants() const override { return true; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + static FunctionPtr create(ContextPtr) { return std::make_shared>(); } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { - if (arguments.size() > 2 || arguments.empty()) + if (arguments.size() > Traits::max_num_params || arguments.empty()) { throw Exception( - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function '{}' needs 1 or 2 arguments, passed {}.", getName(), arguments.size()); + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Function '{}' needs from 1 to {} arguments, passed {}.", + getName(), + Traits::max_num_params, + arguments.size()); } const DataTypeArray * array_type = checkAndGetDataType(arguments[0].get()); if (!array_type) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function '{}' must be array", getName()); - if (arguments.size() == 2) + auto check_is_integral = [&](auto param_idx) { - WhichDataType which(arguments[1]); + WhichDataType which(arguments[param_idx]); if (!which.isUInt() && !which.isInt()) throw Exception( - "Illegal type " + arguments[1]->getName() + " of argument of function " + getName() + " (must be UInt or Int)", + "Illegal type " + arguments[param_idx]->getName() + " of argument of function " + getName() + " (must be UInt or Int)", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - } + }; + + for (size_t idx = 1; idx < arguments.size(); ++idx) + check_is_integral(idx); return arguments[0]; } - bool useDefaultImplementationForConstants() const override { return true; } - 
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t) const override; private: - static ColumnPtr executeGeneric(const ColumnArray & array, ColumnPtr mapped, pcg64_fast & rng); + static ColumnPtr executeGeneric(const ColumnArray & array, pcg64_fast & rng, size_t limit); }; -ColumnPtr FunctionArrayShuffle::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const +template +ColumnPtr FunctionArrayShuffleImpl::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const { const ColumnArray * array = checkAndGetColumn(arguments[0].column.get()); if (!array) @@ -82,17 +115,32 @@ ColumnPtr FunctionArrayShuffle::executeImpl(const ColumnsWithTypeAndName & argum const auto seed = [&]() -> uint64_t { - if (arguments.size() == 1) + // If present, seed comes as the last argument + if (arguments.size() != Traits::max_num_params) return randomSeed(); - const auto * val = arguments[1].column.get(); + const auto * val = arguments[Traits::seed_param_idx].column.get(); return val->getUInt(0); }(); pcg64_fast rng(seed); - return executeGeneric(*array, array->getDataPtr(), rng); + size_t limit = [&] + { + if constexpr (Traits::has_limit) + { + if (arguments.size() > 1) + { + const auto * val = arguments[1].column.get(); + return val->getUInt(0); + } + } + return static_cast(0); + }(); + + return executeGeneric(*array, rng, limit); } -ColumnPtr FunctionArrayShuffle::executeGeneric(const ColumnArray & array, ColumnPtr /*mapped*/, pcg64_fast & rng) +template +ColumnPtr FunctionArrayShuffleImpl::executeGeneric(const ColumnArray & array, pcg64_fast & rng, size_t limit [[maybe_unused]]) { const ColumnArray::Offsets & offsets = array.getOffsets(); @@ -105,7 +153,15 @@ ColumnPtr FunctionArrayShuffle::executeGeneric(const ColumnArray & array, Column for 
(size_t i = 0; i < size; ++i) { auto next_offset = offsets[i]; - std::shuffle(&permutation[current_offset], &permutation[next_offset], rng); + if constexpr (Traits::has_limit) + { + if (limit && next_offset > limit) + { + partial_shuffle(&permutation[current_offset], &permutation[next_offset], limit, rng); + break; + } + } + shuffle(&permutation[current_offset], &permutation[next_offset], rng); current_offset = next_offset; } return ColumnArray::create(array.getData().permute(permutation, 0), array.getOffsetsPtr()); @@ -113,7 +169,7 @@ ColumnPtr FunctionArrayShuffle::executeGeneric(const ColumnArray & array, Column REGISTER_FUNCTION(ArrayShuffle) { - factory.registerFunction( + factory.registerFunction>( { R"( Returns an array of the same size as the original array containing the elements in shuffled order. @@ -131,6 +187,29 @@ It is possible to override the seed to produce stable results: Documentation::Categories{"Array"} }, FunctionFactory::CaseInsensitive); + factory.registerFunction>( + { + R"( +Returns an array of the same size as the original array where elements in range [0..limit) are a random +subset of the original array. Remaining [limit..n) shall contain the elements not in [0..limit) range in undefined order. +Value of limit shall be in range [1..n]. 
Values outside of that range are equivalent to performing full arrayShuffle: +[example:no_limit1] +[example:no_limit2] + +If no seed is provided a random one will be used: +[example:random_seed] + +It is possible to override the seed to produce stable results: +[example:explicit_seed] +)", + Documentation::Examples{ + {"no_limit1", "SELECT arrayPartialShuffle([1, 2, 3, 4], 0)"}, + {"no_limit2", "SELECT arrayPartialShuffle([1, 2, 3, 4])"}, + {"random_seed", "SELECT arrayPartialShuffle([1, 2, 3, 4], 2)"}, + {"explicit_seed", "SELECT arrayPartialShuffle([1, 2, 3, 4], 2, 41)"}}, + Documentation::Categories{"Array"} + }, + FunctionFactory::CaseInsensitive); } } diff --git a/tests/queries/0_stateless/02523_array_shuffle.reference b/tests/queries/0_stateless/02523_array_shuffle.reference index a92ad2a05c6..2263f8dc92a 100644 --- a/tests/queries/0_stateless/02523_array_shuffle.reference +++ b/tests/queries/0_stateless/02523_array_shuffle.reference @@ -18,3 +18,45 @@ [10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,5,48,67,90,20,27,38,19,54,21,83,84,1,22,56,81,91,77,36,63,33,39,24,40,4,99,14,23,94,29,26,96,2,28,31,57,42,88,12,47,58,8,37,82,92,34,6,60,25,43,50,74,70,52,55,62,17,79,65,93,86,7,16,41,59,75,80,45,69,89,85,87,95,64,61,98,49,78,66,15] [(3,-3),(1,-1),(99999999,-99999999)] [(3,'A'),(1,NULL),(2,'a')] +[] +[] +[] +[9223372036854775808] +[9223372036854775808] +[9223372036854775808] +[10,9,4,2,5,6,7,1,8,3] +[10.1,9,4,2,5,6,7,1,8,3] +[9223372036854775808,9,4,2,5,6,7,1,8,3] +[NULL,9,4,2,5,6,7,1,8,3] +['789','123','ABC','000','456'] +['789','123','ABC',NULL,'456'] +['imposter','storage','sensation','uniform','tiger','terminal'] +[NULL,'storage','sensation','uniform','tiger','terminal'] +[NULL] +[NULL,NULL] +[[10,20,30,40],[1,2,3,4],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64],[-1,-2,-3,-4]] +[[10,20,30,40],[1,2,3,4],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64],[NULL,-2,-3,-4]] 
+[10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,5,48,67,90,20,27,38,19,54,21,83,84,1,22,56,81,91,77,36,63,33,39,24,40,4,99,14,23,94,29,26,96,2,28,31,57,42,88,12,47,58,8,37,82,92,34,6,60,25,43,50,74,70,52,55,62,17,79,65,93,86,7,16,41,59,75,80,45,69,89,85,87,95,64,61,98,49,78,66,15] +[10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,5,48,67,90,20,27,38,19,54,21,83,84,1,22,56,81,91,77,36,63,33,39,24,40,4,99,14,23,94,29,26,96,2,28,31,57,42,88,12,47,58,8,37,82,92,34,6,60,25,43,50,74,70,52,55,62,17,79,65,93,86,7,16,41,59,75,80,45,69,89,85,87,95,64,61,98,49,78,66,15] +[(3,-3),(1,-1),(99999999,-99999999)] +[(3,'A'),(1,NULL),(2,'a')] +[NULL,NULL,NULL] +[10,2,3,4,5,6,7,8,9,1] +[10,9,3,4,5,6,7,8,2,1] +[10,9,4,2,5,6,7,8,3,1] +[10,9,4,2,5,6,7,1,3,8] +[10,9,4,2,5,6,7,1,8,3] +[10,9,4,2,5,6,7,1,8,3] +[10.1,9,4,2,5,6,7,8,3,1] +[9223372036854775808,9,4,2,5,6,7,8,3,1] +[NULL,9,4,2,5,6,7,8,3,1] +['789','123','ABC','456','000'] +['789','123','ABC','456',NULL] +['imposter','storage','sensation','terminal','uniform','tiger'] +[NULL,'storage','sensation','terminal','uniform','tiger'] +[[10,20,30,40],[1,2,3,4],[-1,-2,-3,-4],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]] +[[10,20,30,40],[1,2,3,4],[NULL,-2,-3,-4],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]] +[10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,20,21,22,23,24,25,26,27,28,29,17,31,15,33,34,2,36,37,38,39,40,41,42,43,8,45,6,47,48,49,50,16,52,14,54,55,56,57,58,59,60,61,62,63,64,65,66,67,19,69,70,7,1,4,74,75,5,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,12,98,99] +[10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,20,21,22,23,24,25,26,27,28,29,17,31,15,33,34,2,36,37,38,39,40,41,42,43,8,45,6,47,48,49,50,16,52,14,54,55,56,57,58,59,60,61,62,63,64,65,66,67,19,69,70,7,1,4,74,75,5,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,12,98,99] +[(3,-3),(1,-1),(99999999,-99999999)] +[(3,'A'),(1,NULL),(2,'a')] diff --git 
a/tests/queries/0_stateless/02523_array_shuffle.sql b/tests/queries/0_stateless/02523_array_shuffle.sql index ecbc9e649d4..dfeb75e01c5 100644 --- a/tests/queries/0_stateless/02523_array_shuffle.sql +++ b/tests/queries/0_stateless/02523_array_shuffle.sql @@ -18,6 +18,48 @@ SELECT arrayShuffle(groupArray(x),0xbad_cafe) FROM (SELECT number as x from syst SELECT arrayShuffle(groupArray(toUInt64(x)),0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100); SELECT arrayShuffle([tuple(1, -1), tuple(99999999, -99999999), tuple(3, -3)], 0xbad_cafe); SELECT arrayShuffle([tuple(1, NULL), tuple(2, 'a'), tuple(3, 'A')], 0xbad_cafe); +SELECT arrayPartialShuffle([]); -- trivial cases (equivalent to arrayShuffle) +SELECT arrayPartialShuffle([], 0); +SELECT arrayPartialShuffle([], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([9223372036854775808]); +SELECT arrayPartialShuffle([9223372036854775808], 0); +SELECT arrayPartialShuffle([9223372036854775808], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10.1], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,9223372036854775808], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,NULL], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([toFixedString('123', 3), toFixedString('456', 3), toFixedString('789', 3), toFixedString('ABC', 3), toFixedString('000', 3)], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([toFixedString('123', 3), toFixedString('456', 3), toFixedString('789', 3), toFixedString('ABC', 3), NULL], 0, 0xbad_cafe); +SELECT arrayPartialShuffle(['storage','tiger','imposter','terminal','uniform','sensation'], 0, 0xbad_cafe); +SELECT arrayPartialShuffle(['storage','tiger',NULL,'terminal','uniform','sensation'], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([NULL]); +SELECT arrayPartialShuffle([NULL,NULL]); +SELECT 
arrayPartialShuffle([[1,2,3,4],[-1,-2,-3,-4],[10,20,30,40],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([[1,2,3,4],[NULL,-2,-3,-4],[10,20,30,40],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]], 0, 0xbad_cafe); +SELECT arrayPartialShuffle(groupArray(x),0,0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100); +SELECT arrayPartialShuffle(groupArray(toUInt64(x)),0,0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100); +SELECT arrayPartialShuffle([tuple(1, -1), tuple(99999999, -99999999), tuple(3, -3)], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([tuple(1, NULL), tuple(2, 'a'), tuple(3, 'A')], 0, 0xbad_cafe); +SELECT arrayPartialShuffle([NULL,NULL,NULL], 2); -- other, mostly non-trivial cases +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 1, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 2, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 4, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 8, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 9, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 10, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10.1], 4, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,9223372036854775808], 4, 0xbad_cafe); +SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,NULL], 4, 0xbad_cafe); +SELECT arrayPartialShuffle([toFixedString('123', 3), toFixedString('456', 3), toFixedString('789', 3), toFixedString('ABC', 3), toFixedString('000', 3)], 3, 0xbad_cafe); +SELECT arrayPartialShuffle([toFixedString('123', 3), toFixedString('456', 3), toFixedString('789', 3), toFixedString('ABC', 3), NULL], 3, 0xbad_cafe); +SELECT arrayPartialShuffle(['storage','tiger','imposter','terminal','uniform','sensation'], 3, 0xbad_cafe); +SELECT arrayPartialShuffle(['storage','tiger',NULL,'terminal','uniform','sensation'], 3, 0xbad_cafe); +SELECT 
arrayPartialShuffle([[1,2,3,4],[-1,-2,-3,-4],[10,20,30,40],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]], 2, 0xbad_cafe); +SELECT arrayPartialShuffle([[1,2,3,4],[NULL,-2,-3,-4],[10,20,30,40],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]], 2, 0xbad_cafe); +SELECT arrayPartialShuffle(groupArray(x),20,0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100); +SELECT arrayPartialShuffle(groupArray(toUInt64(x)),20,0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100); +SELECT arrayPartialShuffle([tuple(1, -1), tuple(99999999, -99999999), tuple(3, -3)], 2, 0xbad_cafe); +SELECT arrayPartialShuffle([tuple(1, NULL), tuple(2, 'a'), tuple(3, 'A')], 2, 0xbad_cafe); SELECT arrayShuffle(1); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } SELECT arrayShuffle([1], 'a'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } SELECT arrayShuffle([1], 1.1); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }