Merge pull request #54391 from itayisraelov/israelov/generate-random-int-array

Add function `arrayRandomSample()`
This commit is contained in:
Robert Schulze 2023-10-08 18:28:58 +02:00 committed by GitHub
commit 32a77ca1eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 489 additions and 0 deletions

View File

@ -2118,6 +2118,80 @@ Result:
└─────────────────────┘ └─────────────────────┘
``` ```
## arrayRandomSample
Function `arrayRandomSample` returns a subset with `samples`-many random elements of an input array. If `samples` exceeds the size of the input array, the sample size is limited to the size of the array. In this case, all elements of the input array are returned, but the order is not guaranteed. The function can handle both flat arrays and nested arrays.
**Syntax**
```sql
arrayRandomSample(arr, samples)
```
**Arguments**
- `arr` — The input array from which to sample elements. This may be flat or nested arrays.
- `samples` — An unsigned integer specifying the number of elements to include in the random sample.
**Returned Value**
- An array containing a random sample of elements from the input array.
**Examples**
Query:
```sql
SELECT arrayRandomSample(['apple', 'banana', 'cherry', 'date'], 2) as res;
```
Result:
```
┌─res────────────────┐
│ ['banana','apple'] │
└────────────────────┘
```
Query:
```sql
SELECT arrayRandomSample([[1, 2], [3, 4], [5, 6]], 2) as res;
```
Result:
```
┌─res───────────┐
│ [[3,4],[5,6]] │
└───────────────┘
```
Query:
```sql
SELECT arrayRandomSample([1, 2, 3, 4, 5], 0) as res;
```
Result:
```
┌─res─┐
│ [] │
└─────┘
```
Query:
```sql
SELECT arrayRandomSample([1, 2, 3], 5) as res;
```
Result:
```
┌─res─────┐
│ [3,1,2] │
└─────────┘
```
## Distance functions ## Distance functions
All supported functions are described in [distance functions documentation](../../sql-reference/functions/distance-functions.md). All supported functions are described in [distance functions documentation](../../sql-reference/functions/distance-functions.md).

View File

@ -0,0 +1,118 @@
#include <random>
#include <Columns/ColumnArray.h>
#include <DataTypes/DataTypeArray.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Poco/Logger.h>
#include "Columns/ColumnsNumber.h"
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// arrayRandomSample(arr, k) - Returns k random elements from the input array
class FunctionArrayRandomSample : public IFunction
{
public:
static constexpr auto name = "arrayRandomSample";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayRandomSample>(); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
bool useDefaultImplementationForConstants() const override { return true; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
FunctionArgumentDescriptors args{
{"array", &isArray<IDataType>, nullptr, "Array"},
{"samples", &isUnsignedInteger<IDataType>, isColumnConst, "const UInt*"},
};
validateFunctionArgumentTypes(*this, arguments, args);
// Return an array with the same nested type as the input array
const DataTypePtr & array_type = arguments[0].type;
const DataTypeArray * array_data_type = checkAndGetDataType<DataTypeArray>(array_type.get());
// Get the nested data type of the array
const DataTypePtr & nested_type = array_data_type->getNestedType();
return std::make_shared<DataTypeArray>(nested_type);
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const ColumnArray * column_array = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
if (!column_array)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument must be an array");
const IColumn * col_samples = arguments[1].column.get();
if (!col_samples)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "The second argument is empty or null, type = {}", arguments[1].type->getName());
UInt64 samples;
try
{
samples = col_samples->getUInt(0);
}
catch (...)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Failed to fetch UInt64 from the second argument column, type = {}",
arguments[1].type->getName());
}
std::random_device rd;
std::mt19937 gen(rd());
auto nested_column = column_array->getDataPtr()->cloneEmpty();
auto offsets_column = ColumnUInt64::create();
auto res_data = ColumnArray::create(std::move(nested_column), std::move(offsets_column));
const auto & input_offsets = column_array->getOffsets();
auto & res_offsets = res_data->getOffsets();
res_offsets.resize(input_rows_count);
UInt64 cur_samples;
size_t current_offset = 0;
for (size_t row = 0; row < input_rows_count; row++)
{
size_t row_size = input_offsets[row] - current_offset;
std::vector<size_t> indices(row_size);
std::iota(indices.begin(), indices.end(), 0);
std::shuffle(indices.begin(), indices.end(), gen);
cur_samples = std::min(samples, static_cast<UInt64>(row_size));
for (UInt64 j = 0; j < cur_samples; j++)
{
size_t source_index = indices[j];
res_data->getData().insertFrom(column_array->getData(), source_index);
}
res_offsets[row] = current_offset + cur_samples;
current_offset += cur_samples;
}
return res_data;
}
};
REGISTER_FUNCTION(ArrayRandomSample)
{
factory.registerFunction<FunctionArrayRandomSample>();
}
}

View File

@ -126,6 +126,7 @@ arrayPopFront
arrayProduct arrayProduct
arrayPushBack arrayPushBack
arrayPushFront arrayPushFront
arrayRandomSample
arrayReduce arrayReduce
arrayReduceInRanges arrayReduceInRanges
arrayResize arrayResize

View File

@ -0,0 +1,37 @@
Running iteration: 1
Integer Test: Passed
String Test: Passed
Nested Array Test: Passed
Higher Sample Number Test: Passed
Multi-row Test with scalar k: Passed
Running iteration: 2
Integer Test: Passed
String Test: Passed
Nested Array Test: Passed
Higher Sample Number Test: Passed
Multi-row Test with scalar k: Passed
Running iteration: 3
Integer Test: Passed
String Test: Passed
Nested Array Test: Passed
Higher Sample Number Test: Passed
Multi-row Test with scalar k: Passed
Running iteration: 4
Integer Test: Passed
String Test: Passed
Nested Array Test: Passed
Higher Sample Number Test: Passed
Multi-row Test with scalar k: Passed
Running iteration: 5
Integer Test: Passed
String Test: Passed
Nested Array Test: Passed
Higher Sample Number Test: Passed
Multi-row Test with scalar k: Passed
Integer Test with K=0: Passed
Empty Array with K > 0 Test: Passed
Non-Unsigned-Integer K Test (Negative Integer): Passed
Non-Unsigned-Integer K Test (String): Passed
Non-Unsigned-Integer K Test (Floating-Point): Passed
Total tests: 30
Passed tests: 30

View File

@ -0,0 +1,258 @@
#!/usr/bin/env bash
CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CUR_DIR"/../shell_config.sh
# Initialize variables
total_tests=0
passed_tests=0
# Test Function for Integer Arrays
run_integer_test() {
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1,2,3], 2)")
mapfile -t sorted_result < <(echo "$query_result" | tr -d '[]' | tr ',' '\n' | sort -n)
declare -A expected_outcomes
expected_outcomes["1 2"]=1
expected_outcomes["1 3"]=1
expected_outcomes["2 3"]=1
expected_outcomes["2 1"]=1
expected_outcomes["3 1"]=1
expected_outcomes["3 2"]=1
sorted_result_str=$(echo "${sorted_result[*]}" | tr ' ' '\n' | sort -n | tr '\n' ' ' | sed 's/ $//')
if [[ -n "${expected_outcomes[$sorted_result_str]}" ]]; then
echo "Integer Test: Passed"
((passed_tests++))
else
echo "Integer Test: Failed"
echo "Output: $query_result"
fi
((total_tests++))
}
# Test Function for String Arrays
run_string_test() {
query_result=$(clickhouse-client -q "SELECT arrayRandomSample(['a','b','c'], 2)")
mapfile -t sorted_result < <(echo "$query_result" | tr -d "[]'" | tr ',' '\n' | sort)
declare -A expected_outcomes
expected_outcomes["a b"]=1
expected_outcomes["a c"]=1
expected_outcomes["b c"]=1
expected_outcomes["b a"]=1
expected_outcomes["c a"]=1
expected_outcomes["c b"]=1
sorted_result_str=$(echo "${sorted_result[*]}" | tr ' ' '\n' | sort | tr '\n' ' ' | sed 's/ $//')
if [[ -n "${expected_outcomes[$sorted_result_str]}" ]]; then
echo "String Test: Passed"
((passed_tests++))
else
echo "String Test: Failed"
echo "Output: $query_result"
fi
((total_tests++))
}
# Test Function for Nested Arrays
run_nested_array_test() {
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([[7,2],[3,4],[7,6]], 2)")
# Convert to a space-separated string for easy sorting.
converted_result=$(echo "$query_result" | tr -d '[]' | tr ',' ' ')
# Sort the string.
sorted_result_str=$(echo "$converted_result" | tr ' ' '\n' | xargs -n2 | sort | tr '\n' ' ' | sed 's/ $//')
# Define all possible expected outcomes, sorted
declare -A expected_outcomes
expected_outcomes["7 2 3 4"]=1
expected_outcomes["7 2 7 6"]=1
expected_outcomes["3 4 7 6"]=1
expected_outcomes["3 4 7 2"]=1
expected_outcomes["7 6 7 2"]=1
expected_outcomes["7 6 3 4"]=1
if [[ -n "${expected_outcomes[$sorted_result_str]}" ]]; then
echo "Nested Array Test: Passed"
((passed_tests++))
else
echo "Nested Array Test: Failed"
echo "Output: $query_result"
echo "Processed Output: ${sorted_result_str}"
fi
((total_tests++))
}
# Test Function for K > array.size
run_higher_k_test() {
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1,2,3], 5)")
mapfile -t sorted_result < <(echo "$query_result" | tr -d '[]' | tr ',' '\n' | sort -n)
sorted_original=("1" "2" "3")
are_arrays_equal=true
for i in "${!sorted_result[@]}"; do
if [[ "${sorted_result[$i]}" != "${sorted_original[$i]}" ]]; then
are_arrays_equal=false
break
fi
done
if $are_arrays_equal; then
echo "Higher Sample Number Test: Passed"
((passed_tests++))
else
echo "Higher Sample Number Test: Failed"
echo "Output: $query_result"
fi
((total_tests++))
}
# Test Function for Integer Arrays with samples = 0
run_integer_with_samples_0_test() {
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1,2,3], 0)")
mapfile -t sorted_result < <(echo "$query_result" | tr -d '[]' | tr ',' '\n' | sort -n)
# An empty array should produce an empty string after transformations
declare -A expected_outcomes
expected_outcomes["EMPTY_ARRAY"]=1
# Prepare the result string for comparison
sorted_result_str=$(echo "${sorted_result[*]}" | tr ' ' '\n' | sort -n | tr '\n' ' ' | sed 's/ $//')
# Use "EMPTY_ARRAY" as a placeholder for an empty array
[[ -z "$sorted_result_str" ]] && sorted_result_str="EMPTY_ARRAY"
# Compare
if [[ -n "${expected_outcomes[$sorted_result_str]}" ]]; then
echo "Integer Test with K=0: Passed"
((passed_tests++))
else
echo "Integer Test with K=0: Failed"
echo "Output: $query_result"
fi
((total_tests++))
}
# Test Function for Empty Array with K > 0
run_empty_array_with_k_test() {
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([], 5)")
if [[ "$query_result" == "[]" ]]; then
echo "Empty Array with K > 0 Test: Passed"
((passed_tests++))
else {
echo "Empty Array with K > 0 Test: Failed"
echo "Output: $query_result"
}
fi
((total_tests++))
}
# Test Function for Non-Unsigned-Integer K
run_non_unsigned_integer_k_test() {
# Test with negative integer
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1, 2, 3], -5)" 2>&1)
if [[ "$query_result" == *"ILLEGAL_TYPE_OF_ARGUMENT"* ]]; then
echo "Non-Unsigned-Integer K Test (Negative Integer): Passed"
((passed_tests++))
else {
echo "Non-Unsigned-Integer K Test (Negative Integer): Failed"
echo "Output: $query_result"
}
fi
((total_tests++))
# Test with string
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1, 2, 3], 'a')" 2>&1)
if [[ "$query_result" == *"ILLEGAL_TYPE_OF_ARGUMENT"* ]]; then
echo "Non-Unsigned-Integer K Test (String): Passed"
((passed_tests++))
else {
echo "Non-Unsigned-Integer K Test (String): Failed"
echo "Output: $query_result"
}
fi
((total_tests++))
# Test with floating-point number
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1, 2, 3], 1.5)" 2>&1)
if [[ "$query_result" == *"ILLEGAL_TYPE_OF_ARGUMENT"* ]]; then
echo "Non-Unsigned-Integer K Test (Floating-Point): Passed"
((passed_tests++))
else {
echo "Non-Unsigned-Integer K Test (Floating-Point): Failed"
echo "Output: $query_result"
}
fi
((total_tests++))
}
# Function to run a multi-row test with scalar 'k'
run_multi_row_scalar_k_test() {
# Create a table. Use a random database name as tests potentially run in parallel.
db=`tr -dc A-Za-z0-9 </dev/urandom | head -c 13`
clickhouse-client -q "DROP DATABASE IF EXISTS ${db}"
clickhouse-client -q "CREATE DATABASE ${db}"
clickhouse-client -q "CREATE TABLE ${db}.array_test (arr Array(Int32)) ENGINE = Memory"
# Insert multi-row data into the table
clickhouse-client -q "INSERT INTO ${db}.array_test VALUES ([1, 2, 3]), ([4, 5, 6]), ([7, 8, 9])"
# Query using arrayRandomSample function and store the result, k is scalar here (for example, 2)
query_result=$(clickhouse-client -q "SELECT arrayRandomSample(arr, 2) FROM ${db}.array_test")
# Drop the table
clickhouse-client -q "DROP DATABASE ${db}"
# Validate the output here
is_test_passed=1 # flag to indicate if the test passed; 1 means passed, 0 means failed
# Iterate over each line (each array) in the output
echo "$query_result" | while read -r line; do
# Remove brackets from the array string
line=$(echo "$line" | tr -d '[]')
# Convert the string to an array
IFS=", " read -ra nums <<< "$line"
# Check if the array contains exactly 2 unique elements
if [[ ${#nums[@]} -ne 2 ]] || [[ ${nums[0]} -eq ${nums[1]} ]]; then
# shellcheck disable=SC2030
is_test_passed=0
fi
done
# Print test result
# shellcheck disable=SC2031
if [[ $is_test_passed -eq 1 ]]; then
echo "Multi-row Test with scalar k: Passed"
((passed_tests++))
else
echo "Multi-row Test with scalar k: Failed"
echo "Output: $query_result"
fi
((total_tests++))
}
# Run test multiple times
for i in {1..5}; do
echo "Running iteration: $i"
run_integer_test
run_string_test
run_nested_array_test
run_higher_k_test
run_multi_row_scalar_k_test
done
run_integer_with_samples_0_test
run_empty_array_with_k_test
run_non_unsigned_integer_k_test
# Print overall test results
echo "Total tests: $total_tests"
echo "Passed tests: $passed_tests"

View File

@ -1061,6 +1061,7 @@ arrayPopFront
arrayProduct arrayProduct
arrayPushBack arrayPushBack
arrayPushFront arrayPushFront
arrayRandomSample
arrayReduce arrayReduce
arrayReduceInRanges arrayReduceInRanges
arrayResize arrayResize