Added arrayPartialShuffle function

This commit is contained in:
Joanna Hulboj 2023-01-20 20:39:33 +00:00
parent 31eb936457
commit a8b78abc54
3 changed files with 186 additions and 23 deletions

View File

@ -6,11 +6,13 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <pcg_random.hpp>
#include <Common/assert_cast.h>
#include <Common/randomSeed.h>
#include <Common/shuffle.h>
#include <Common/typeid_cast.h>
#include <pcg_random.hpp>
#include <algorithm>
#include <numeric>
@ -28,52 +30,83 @@ namespace ErrorCodes
* arrayShuffle(arr)
* arrayShuffle(arr, seed)
*/
class FunctionArrayShuffle : public IFunction
/// Compile-time settings that turn FunctionArrayShuffleImpl into arrayShuffle(arr[, seed]).
struct FunctionArrayShuffleTraits
{
    /// Name the function is registered and reported under.
    static constexpr const char * name = "arrayShuffle";

    /// No limit argument: the whole array is permuted.
    static constexpr bool has_limit = false;

    /// Accepted call shape: array[, seed].
    static constexpr int max_num_params = 2;

    /// Position of the optional seed in the argument list.
    static constexpr int seed_param_idx = 1;

    /// Indices of the arguments that have to be constants (only the seed here).
    static ColumnNumbers getArgumentsThatAreAlwaysConstant()
    {
        return ColumnNumbers{1};
    }
};
/** Partially shuffle array elements. Supported call shapes:
  *   arrayPartialShuffle(arr)
  *   arrayPartialShuffle(arr, limit)
  *   arrayPartialShuffle(arr, limit, seed)
  *
  * Compile-time settings that turn FunctionArrayShuffleImpl into arrayPartialShuffle.
  */
struct FunctionArrayPartialShuffleTraits
{
    /// Name the function is registered and reported under.
    static constexpr const char * name = "arrayPartialShuffle";

    /// The second argument, when present, is the limit.
    static constexpr bool has_limit = true;

    /// Accepted call shape: array[, limit[, seed]].
    static constexpr int max_num_params = 3;

    /// Position of the optional seed in the argument list.
    static constexpr int seed_param_idx = 2;

    /// Indices of the arguments that have to be constants (limit and seed).
    static ColumnNumbers getArgumentsThatAreAlwaysConstant()
    {
        return ColumnNumbers{1, 2};
    }
};
/// Shared implementation behind the array shuffle functions. Whatever differs between them
/// (function name, maximum number of arguments, list of always-constant arguments) is read
/// from the `Traits` template parameter.
/// NOTE(review): this text appears to be a rendered diff - several members below occur twice
/// (a superseded and a replacement form side by side); confirm against the real file.
template <typename Traits>
class FunctionArrayShuffleImpl : public IFunction
{
public:
static constexpr auto name = "arrayShuffle";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayShuffle>(); }
static constexpr auto name = Traits::name;
/// The function is variadic: getNumberOfArguments() reports 0 and the real argument-count
/// check is done in getReturnTypeImpl().
String getName() const override { return name; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return Traits::getArgumentsThatAreAlwaysConstant(); }
bool useDefaultImplementationForConstants() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayShuffleImpl<Traits>>(); }
/// Validates the argument list and returns the type of the first argument unchanged.
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH when the argument list is empty or too long, and
/// ILLEGAL_TYPE_OF_ARGUMENT when the first argument is not an array or a later argument is
/// neither UInt nor Int.
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (arguments.size() > 2 || arguments.empty())
if (arguments.size() > Traits::max_num_params || arguments.empty())
{
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function '{}' needs 1 or 2 arguments, passed {}.", getName(), arguments.size());
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function '{}' needs from 1 to {} arguments, passed {}.",
getName(),
Traits::max_num_params,
arguments.size());
}
// The first argument has to be an array type.
const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].get());
if (!array_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function '{}' must be array", getName());
if (arguments.size() == 2)
auto check_is_integral = [&](auto param_idx)
{
WhichDataType which(arguments[1]);
WhichDataType which(arguments[param_idx]);
if (!which.isUInt() && !which.isInt())
throw Exception(
"Illegal type " + arguments[1]->getName() + " of argument of function " + getName() + " (must be UInt or Int)",
"Illegal type " + arguments[param_idx]->getName() + " of argument of function " + getName() + " (must be UInt or Int)",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
};
// Every argument after the array goes through the same integer check.
for (size_t idx = 1; idx < arguments.size(); ++idx)
check_is_integral(idx);
return arguments[0];
}
bool useDefaultImplementationForConstants() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
/// Declared here, defined out of line after the class.
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t) const override;
private:
static ColumnPtr executeGeneric(const ColumnArray & array, ColumnPtr mapped, pcg64_fast & rng);
static ColumnPtr executeGeneric(const ColumnArray & array, pcg64_fast & rng, size_t limit);
};
ColumnPtr FunctionArrayShuffle::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const
template <typename Traits>
ColumnPtr FunctionArrayShuffleImpl<Traits>::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const
{
const ColumnArray * array = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
if (!array)
@ -82,17 +115,32 @@ ColumnPtr FunctionArrayShuffle::executeImpl(const ColumnsWithTypeAndName & argum
const auto seed = [&]() -> uint64_t
{
if (arguments.size() == 1)
// If present, seed comes as the last argument
if (arguments.size() != Traits::max_num_params)
return randomSeed();
const auto * val = arguments[1].column.get();
const auto * val = arguments[Traits::seed_param_idx].column.get();
return val->getUInt(0);
}();
pcg64_fast rng(seed);
return executeGeneric(*array, array->getDataPtr(), rng);
size_t limit = [&]
{
if constexpr (Traits::has_limit)
{
if (arguments.size() > 1)
{
const auto * val = arguments[1].column.get();
return val->getUInt(0);
}
}
return static_cast<size_t>(0);
}();
return executeGeneric(*array, rng, limit);
}
ColumnPtr FunctionArrayShuffle::executeGeneric(const ColumnArray & array, ColumnPtr /*mapped*/, pcg64_fast & rng)
template <typename Traits>
ColumnPtr FunctionArrayShuffleImpl<Traits>::executeGeneric(const ColumnArray & array, pcg64_fast & rng, size_t limit [[maybe_unused]])
{
const ColumnArray::Offsets & offsets = array.getOffsets();
@ -105,7 +153,15 @@ ColumnPtr FunctionArrayShuffle::executeGeneric(const ColumnArray & array, Column
// Shuffle every array of the column on its own. The elements of the i-th array occupy
// [current_offset, offsets[i]) of `permutation`.
// NOTE(review): `size`, `permutation` and `current_offset` are set up above this loop, outside
// the visible part of the function - presumably the number of arrays, an identity permutation
// and 0; confirm.
for (size_t i = 0; i < size; ++i)
{
    const auto next_offset = offsets[i];
    bool shuffled_partially = false;
    if constexpr (Traits::has_limit)
    {
        // The limit is measured against the length of this array, not against its cumulative
        // end offset: comparing `next_offset` itself is only right for the very first array.
        // A zero limit, or a limit that does not cut the array short, means a full shuffle.
        if (limit && next_offset - current_offset > limit)
        {
            partial_shuffle(&permutation[current_offset], &permutation[next_offset], limit, rng);
            shuffled_partially = true;
        }
    }
    if (!shuffled_partially)
        shuffle(&permutation[current_offset], &permutation[next_offset], rng);
    // Always advance to the next array; leaving the loop early would keep the remaining
    // arrays in their original order.
    current_offset = next_offset;
}
return ColumnArray::create(array.getData().permute(permutation, 0), array.getOffsetsPtr());
@ -113,7 +169,7 @@ ColumnPtr FunctionArrayShuffle::executeGeneric(const ColumnArray & array, Column
REGISTER_FUNCTION(ArrayShuffle)
{
factory.registerFunction<FunctionArrayShuffle>(
factory.registerFunction<FunctionArrayShuffleImpl<FunctionArrayShuffleTraits>>(
{
R"(
Returns an array of the same size as the original array containing the elements in shuffled order.
@ -131,6 +187,29 @@ It is possible to override the seed to produce stable results:
Documentation::Categories{"Array"}
},
FunctionFactory::CaseInsensitive);
// Registers arrayPartialShuffle (case-insensitive name) together with its built-in documentation.
// Every example has to call arrayPartialShuffle itself: arrayShuffle accepts at most two
// arguments, so a three-argument arrayShuffle example would fail.
factory.registerFunction<FunctionArrayShuffleImpl<FunctionArrayPartialShuffleTraits>>(
    {
        R"(
Returns an array of the same size as the original array where elements in range [0..limit) are a random
subset of the original array. Remaining [limit..n) shall contain the elements not in [0..limit) range in undefined order.
Value of limit shall be in range [1..n]. Values outside of that range are equivalent to performing full arrayShuffle:
[example:no_limit1]
[example:no_limit2]
If no seed is provided a random one will be used:
[example:random_seed]
It is possible to override the seed to produce stable results:
[example:explicit_seed]
)",
        Documentation::Examples{
            {"no_limit1", "SELECT arrayPartialShuffle([1, 2, 3, 4], 0)"},
            {"no_limit2", "SELECT arrayPartialShuffle([1, 2, 3, 4])"},
            {"random_seed", "SELECT arrayPartialShuffle([1, 2, 3, 4], 2)"},
            {"explicit_seed", "SELECT arrayPartialShuffle([1, 2, 3, 4], 2, 41)"}},
        Documentation::Categories{"Array"}
    },
    FunctionFactory::CaseInsensitive);
}
}

View File

@ -18,3 +18,45 @@
[10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,5,48,67,90,20,27,38,19,54,21,83,84,1,22,56,81,91,77,36,63,33,39,24,40,4,99,14,23,94,29,26,96,2,28,31,57,42,88,12,47,58,8,37,82,92,34,6,60,25,43,50,74,70,52,55,62,17,79,65,93,86,7,16,41,59,75,80,45,69,89,85,87,95,64,61,98,49,78,66,15]
[(3,-3),(1,-1),(99999999,-99999999)]
[(3,'A'),(1,NULL),(2,'a')]
[]
[]
[]
[9223372036854775808]
[9223372036854775808]
[9223372036854775808]
[10,9,4,2,5,6,7,1,8,3]
[10.1,9,4,2,5,6,7,1,8,3]
[9223372036854775808,9,4,2,5,6,7,1,8,3]
[NULL,9,4,2,5,6,7,1,8,3]
['789','123','ABC','000','456']
['789','123','ABC',NULL,'456']
['imposter','storage','sensation','uniform','tiger','terminal']
[NULL,'storage','sensation','uniform','tiger','terminal']
[NULL]
[NULL,NULL]
[[10,20,30,40],[1,2,3,4],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64],[-1,-2,-3,-4]]
[[10,20,30,40],[1,2,3,4],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64],[NULL,-2,-3,-4]]
[10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,5,48,67,90,20,27,38,19,54,21,83,84,1,22,56,81,91,77,36,63,33,39,24,40,4,99,14,23,94,29,26,96,2,28,31,57,42,88,12,47,58,8,37,82,92,34,6,60,25,43,50,74,70,52,55,62,17,79,65,93,86,7,16,41,59,75,80,45,69,89,85,87,95,64,61,98,49,78,66,15]
[10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,5,48,67,90,20,27,38,19,54,21,83,84,1,22,56,81,91,77,36,63,33,39,24,40,4,99,14,23,94,29,26,96,2,28,31,57,42,88,12,47,58,8,37,82,92,34,6,60,25,43,50,74,70,52,55,62,17,79,65,93,86,7,16,41,59,75,80,45,69,89,85,87,95,64,61,98,49,78,66,15]
[(3,-3),(1,-1),(99999999,-99999999)]
[(3,'A'),(1,NULL),(2,'a')]
[NULL,NULL,NULL]
[10,2,3,4,5,6,7,8,9,1]
[10,9,3,4,5,6,7,8,2,1]
[10,9,4,2,5,6,7,8,3,1]
[10,9,4,2,5,6,7,1,3,8]
[10,9,4,2,5,6,7,1,8,3]
[10,9,4,2,5,6,7,1,8,3]
[10.1,9,4,2,5,6,7,8,3,1]
[9223372036854775808,9,4,2,5,6,7,8,3,1]
[NULL,9,4,2,5,6,7,8,3,1]
['789','123','ABC','456','000']
['789','123','ABC','456',NULL]
['imposter','storage','sensation','terminal','uniform','tiger']
[NULL,'storage','sensation','terminal','uniform','tiger']
[[10,20,30,40],[1,2,3,4],[-1,-2,-3,-4],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]]
[[10,20,30,40],[1,2,3,4],[NULL,-2,-3,-4],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]]
[10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,20,21,22,23,24,25,26,27,28,29,17,31,15,33,34,2,36,37,38,39,40,41,42,43,8,45,6,47,48,49,50,16,52,14,54,55,56,57,58,59,60,61,62,63,64,65,66,67,19,69,70,7,1,4,74,75,5,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,12,98,99]
[10,72,11,18,73,76,46,71,44,35,9,0,97,53,13,32,51,30,3,68,20,21,22,23,24,25,26,27,28,29,17,31,15,33,34,2,36,37,38,39,40,41,42,43,8,45,6,47,48,49,50,16,52,14,54,55,56,57,58,59,60,61,62,63,64,65,66,67,19,69,70,7,1,4,74,75,5,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,12,98,99]
[(3,-3),(1,-1),(99999999,-99999999)]
[(3,'A'),(1,NULL),(2,'a')]

View File

@ -18,6 +18,48 @@ SELECT arrayShuffle(groupArray(x),0xbad_cafe) FROM (SELECT number as x from syst
SELECT arrayShuffle(groupArray(toUInt64(x)),0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100);
SELECT arrayShuffle([tuple(1, -1), tuple(99999999, -99999999), tuple(3, -3)], 0xbad_cafe);
SELECT arrayShuffle([tuple(1, NULL), tuple(2, 'a'), tuple(3, 'A')], 0xbad_cafe);
SELECT arrayPartialShuffle([]); -- trivial cases (equivalent to arrayShuffle)
SELECT arrayPartialShuffle([], 0);
SELECT arrayPartialShuffle([], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([9223372036854775808]);
SELECT arrayPartialShuffle([9223372036854775808], 0);
SELECT arrayPartialShuffle([9223372036854775808], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10.1], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,9223372036854775808], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,NULL], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([toFixedString('123', 3), toFixedString('456', 3), toFixedString('789', 3), toFixedString('ABC', 3), toFixedString('000', 3)], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([toFixedString('123', 3), toFixedString('456', 3), toFixedString('789', 3), toFixedString('ABC', 3), NULL], 0, 0xbad_cafe);
SELECT arrayPartialShuffle(['storage','tiger','imposter','terminal','uniform','sensation'], 0, 0xbad_cafe);
SELECT arrayPartialShuffle(['storage','tiger',NULL,'terminal','uniform','sensation'], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([NULL]);
SELECT arrayPartialShuffle([NULL,NULL]);
SELECT arrayPartialShuffle([[1,2,3,4],[-1,-2,-3,-4],[10,20,30,40],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([[1,2,3,4],[NULL,-2,-3,-4],[10,20,30,40],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]], 0, 0xbad_cafe);
SELECT arrayPartialShuffle(groupArray(x),0,0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100);
SELECT arrayPartialShuffle(groupArray(toUInt64(x)),0,0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100);
SELECT arrayPartialShuffle([tuple(1, -1), tuple(99999999, -99999999), tuple(3, -3)], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([tuple(1, NULL), tuple(2, 'a'), tuple(3, 'A')], 0, 0xbad_cafe);
SELECT arrayPartialShuffle([NULL,NULL,NULL], 2); -- other, mostly non-trivial cases
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 1, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 2, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 4, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 8, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 9, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10], 10, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,10.1], 4, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,9223372036854775808], 4, 0xbad_cafe);
SELECT arrayPartialShuffle([1,2,3,4,5,6,7,8,9,NULL], 4, 0xbad_cafe);
SELECT arrayPartialShuffle([toFixedString('123', 3), toFixedString('456', 3), toFixedString('789', 3), toFixedString('ABC', 3), toFixedString('000', 3)], 3, 0xbad_cafe);
SELECT arrayPartialShuffle([toFixedString('123', 3), toFixedString('456', 3), toFixedString('789', 3), toFixedString('ABC', 3), NULL], 3, 0xbad_cafe);
SELECT arrayPartialShuffle(['storage','tiger','imposter','terminal','uniform','sensation'], 3, 0xbad_cafe);
SELECT arrayPartialShuffle(['storage','tiger',NULL,'terminal','uniform','sensation'], 3, 0xbad_cafe);
SELECT arrayPartialShuffle([[1,2,3,4],[-1,-2,-3,-4],[10,20,30,40],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]], 2, 0xbad_cafe);
SELECT arrayPartialShuffle([[1,2,3,4],[NULL,-2,-3,-4],[10,20,30,40],[100,200,300,400,500,600,700,800,900],[2,4,8,16,32,64]], 2, 0xbad_cafe);
SELECT arrayPartialShuffle(groupArray(x),20,0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100);
SELECT arrayPartialShuffle(groupArray(toUInt64(x)),20,0xbad_cafe) FROM (SELECT number as x from system.numbers LIMIT 100);
SELECT arrayPartialShuffle([tuple(1, -1), tuple(99999999, -99999999), tuple(3, -3)], 2, 0xbad_cafe);
SELECT arrayPartialShuffle([tuple(1, NULL), tuple(2, 'a'), tuple(3, 'A')], 2, 0xbad_cafe);
SELECT arrayShuffle(1); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayShuffle([1], 'a'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT arrayShuffle([1], 1.1); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }