mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 08:02:02 +00:00
Merge pull request #54391 from itayisraelov/israelov/generate-random-int-array
Add function `arrayRandomSample()`
This commit is contained in:
commit
32a77ca1eb
@ -2118,6 +2118,80 @@ Result:
|
||||
└─────────────────────┘
|
||||
```
|
||||
|
||||
|
||||
## arrayRandomSample
|
||||
|
||||
Function `arrayRandomSample` returns a subset with `samples`-many random elements of an input array. If `samples` exceeds the size of the input array, the sample size is limited to the size of the array. In this case, all elements of the input array are returned, but the order is not guaranteed. The function can handle both flat arrays and nested arrays.
|
||||
|
||||
**Syntax**
|
||||
|
||||
```sql
|
||||
arrayRandomSample(arr, samples)
|
||||
```
|
||||
|
||||
**Arguments**
|
||||
|
||||
- `arr` — The input array from which to sample elements. This may be flat or nested arrays.
|
||||
- `samples` — An unsigned integer specifying the number of elements to include in the random sample.
|
||||
|
||||
**Returned Value**
|
||||
|
||||
- An array containing a random sample of elements from the input array.
|
||||
|
||||
**Examples**
|
||||
|
||||
Query:
|
||||
|
||||
```sql
|
||||
SELECT arrayRandomSample(['apple', 'banana', 'cherry', 'date'], 2) as res;
|
||||
```
|
||||
|
||||
Result:
|
||||
```
|
||||
┌─res────────────────┐
|
||||
│ ['banana','apple'] │
|
||||
└────────────────────┘
|
||||
```
|
||||
|
||||
Query:
|
||||
|
||||
```sql
|
||||
SELECT arrayRandomSample([[1, 2], [3, 4], [5, 6]], 2) as res;
|
||||
```
|
||||
|
||||
Result:
|
||||
```
|
||||
┌─res───────────┐
|
||||
│ [[3,4],[5,6]] │
|
||||
└───────────────┘
|
||||
```
|
||||
|
||||
Query:
|
||||
|
||||
```sql
|
||||
SELECT arrayRandomSample([1, 2, 3, 4, 5], 0) as res;
|
||||
```
|
||||
|
||||
Result:
|
||||
```
|
||||
┌─res─┐
|
||||
│ [] │
|
||||
└─────┘
|
||||
```
|
||||
|
||||
Query:
|
||||
|
||||
```sql
|
||||
SELECT arrayRandomSample([1, 2, 3], 5) as res;
|
||||
```
|
||||
|
||||
Result:
|
||||
```
|
||||
┌─res─────┐
|
||||
│ [3,1,2] │
|
||||
└─────────┘
|
||||
```
|
||||
|
||||
## Distance functions
|
||||
|
||||
All supported functions are described in [distance functions documentation](../../sql-reference/functions/distance-functions.md).
|
||||
|
118
src/Functions/array/arrayRandomSample.cpp
Normal file
118
src/Functions/array/arrayRandomSample.cpp
Normal file
@ -0,0 +1,118 @@
|
||||
#include <random>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Poco/Logger.h>
|
||||
#include "Columns/ColumnsNumber.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
/// arrayRandomSample(arr, k) - Returns k random elements from the input array
|
||||
class FunctionArrayRandomSample : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "arrayRandomSample";
|
||||
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayRandomSample>(); }
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
size_t getNumberOfArguments() const override { return 2; }
|
||||
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
FunctionArgumentDescriptors args{
|
||||
{"array", &isArray<IDataType>, nullptr, "Array"},
|
||||
{"samples", &isUnsignedInteger<IDataType>, isColumnConst, "const UInt*"},
|
||||
};
|
||||
validateFunctionArgumentTypes(*this, arguments, args);
|
||||
|
||||
// Return an array with the same nested type as the input array
|
||||
const DataTypePtr & array_type = arguments[0].type;
|
||||
const DataTypeArray * array_data_type = checkAndGetDataType<DataTypeArray>(array_type.get());
|
||||
|
||||
// Get the nested data type of the array
|
||||
const DataTypePtr & nested_type = array_data_type->getNestedType();
|
||||
|
||||
return std::make_shared<DataTypeArray>(nested_type);
|
||||
}
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
|
||||
{
|
||||
const ColumnArray * column_array = checkAndGetColumn<ColumnArray>(arguments[0].column.get());
|
||||
if (!column_array)
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument must be an array");
|
||||
|
||||
const IColumn * col_samples = arguments[1].column.get();
|
||||
if (!col_samples)
|
||||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "The second argument is empty or null, type = {}", arguments[1].type->getName());
|
||||
|
||||
UInt64 samples;
|
||||
try
|
||||
{
|
||||
samples = col_samples->getUInt(0);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Failed to fetch UInt64 from the second argument column, type = {}",
|
||||
arguments[1].type->getName());
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
|
||||
auto nested_column = column_array->getDataPtr()->cloneEmpty();
|
||||
auto offsets_column = ColumnUInt64::create();
|
||||
|
||||
auto res_data = ColumnArray::create(std::move(nested_column), std::move(offsets_column));
|
||||
|
||||
const auto & input_offsets = column_array->getOffsets();
|
||||
auto & res_offsets = res_data->getOffsets();
|
||||
res_offsets.resize(input_rows_count);
|
||||
|
||||
UInt64 cur_samples;
|
||||
size_t current_offset = 0;
|
||||
|
||||
for (size_t row = 0; row < input_rows_count; row++)
|
||||
{
|
||||
size_t row_size = input_offsets[row] - current_offset;
|
||||
|
||||
std::vector<size_t> indices(row_size);
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
std::shuffle(indices.begin(), indices.end(), gen);
|
||||
|
||||
cur_samples = std::min(samples, static_cast<UInt64>(row_size));
|
||||
|
||||
for (UInt64 j = 0; j < cur_samples; j++)
|
||||
{
|
||||
size_t source_index = indices[j];
|
||||
res_data->getData().insertFrom(column_array->getData(), source_index);
|
||||
}
|
||||
|
||||
res_offsets[row] = current_offset + cur_samples;
|
||||
current_offset += cur_samples;
|
||||
}
|
||||
|
||||
return res_data;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_FUNCTION(ArrayRandomSample)
|
||||
{
|
||||
factory.registerFunction<FunctionArrayRandomSample>();
|
||||
}
|
||||
|
||||
}
|
@ -126,6 +126,7 @@ arrayPopFront
|
||||
arrayProduct
|
||||
arrayPushBack
|
||||
arrayPushFront
|
||||
arrayRandomSample
|
||||
arrayReduce
|
||||
arrayReduceInRanges
|
||||
arrayResize
|
||||
|
@ -0,0 +1,37 @@
|
||||
Running iteration: 1
|
||||
Integer Test: Passed
|
||||
String Test: Passed
|
||||
Nested Array Test: Passed
|
||||
Higher Sample Number Test: Passed
|
||||
Multi-row Test with scalar k: Passed
|
||||
Running iteration: 2
|
||||
Integer Test: Passed
|
||||
String Test: Passed
|
||||
Nested Array Test: Passed
|
||||
Higher Sample Number Test: Passed
|
||||
Multi-row Test with scalar k: Passed
|
||||
Running iteration: 3
|
||||
Integer Test: Passed
|
||||
String Test: Passed
|
||||
Nested Array Test: Passed
|
||||
Higher Sample Number Test: Passed
|
||||
Multi-row Test with scalar k: Passed
|
||||
Running iteration: 4
|
||||
Integer Test: Passed
|
||||
String Test: Passed
|
||||
Nested Array Test: Passed
|
||||
Higher Sample Number Test: Passed
|
||||
Multi-row Test with scalar k: Passed
|
||||
Running iteration: 5
|
||||
Integer Test: Passed
|
||||
String Test: Passed
|
||||
Nested Array Test: Passed
|
||||
Higher Sample Number Test: Passed
|
||||
Multi-row Test with scalar k: Passed
|
||||
Integer Test with K=0: Passed
|
||||
Empty Array with K > 0 Test: Passed
|
||||
Non-Unsigned-Integer K Test (Negative Integer): Passed
|
||||
Non-Unsigned-Integer K Test (String): Passed
|
||||
Non-Unsigned-Integer K Test (Floating-Point): Passed
|
||||
Total tests: 30
|
||||
Passed tests: 30
|
258
tests/queries/0_stateless/02874_array_random_sample.sh
Executable file
258
tests/queries/0_stateless/02874_array_random_sample.sh
Executable file
@ -0,0 +1,258 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
|
||||
# shellcheck source=../shell_config.sh
|
||||
. "$CUR_DIR"/../shell_config.sh
|
||||
|
||||
# Initialize variables
|
||||
total_tests=0
|
||||
passed_tests=0
|
||||
|
||||
|
||||
# Test Function for Integer Arrays
|
||||
run_integer_test() {
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1,2,3], 2)")
|
||||
mapfile -t sorted_result < <(echo "$query_result" | tr -d '[]' | tr ',' '\n' | sort -n)
|
||||
declare -A expected_outcomes
|
||||
expected_outcomes["1 2"]=1
|
||||
expected_outcomes["1 3"]=1
|
||||
expected_outcomes["2 3"]=1
|
||||
expected_outcomes["2 1"]=1
|
||||
expected_outcomes["3 1"]=1
|
||||
expected_outcomes["3 2"]=1
|
||||
|
||||
sorted_result_str=$(echo "${sorted_result[*]}" | tr ' ' '\n' | sort -n | tr '\n' ' ' | sed 's/ $//')
|
||||
if [[ -n "${expected_outcomes[$sorted_result_str]}" ]]; then
|
||||
echo "Integer Test: Passed"
|
||||
((passed_tests++))
|
||||
else
|
||||
echo "Integer Test: Failed"
|
||||
echo "Output: $query_result"
|
||||
fi
|
||||
((total_tests++))
|
||||
}
|
||||
|
||||
# Test Function for String Arrays
|
||||
run_string_test() {
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample(['a','b','c'], 2)")
|
||||
mapfile -t sorted_result < <(echo "$query_result" | tr -d "[]'" | tr ',' '\n' | sort)
|
||||
declare -A expected_outcomes
|
||||
expected_outcomes["a b"]=1
|
||||
expected_outcomes["a c"]=1
|
||||
expected_outcomes["b c"]=1
|
||||
expected_outcomes["b a"]=1
|
||||
expected_outcomes["c a"]=1
|
||||
expected_outcomes["c b"]=1
|
||||
|
||||
sorted_result_str=$(echo "${sorted_result[*]}" | tr ' ' '\n' | sort | tr '\n' ' ' | sed 's/ $//')
|
||||
if [[ -n "${expected_outcomes[$sorted_result_str]}" ]]; then
|
||||
echo "String Test: Passed"
|
||||
((passed_tests++))
|
||||
else
|
||||
echo "String Test: Failed"
|
||||
echo "Output: $query_result"
|
||||
fi
|
||||
((total_tests++))
|
||||
}
|
||||
|
||||
# Test Function for Nested Arrays
|
||||
run_nested_array_test() {
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([[7,2],[3,4],[7,6]], 2)")
|
||||
# Convert to a space-separated string for easy sorting.
|
||||
converted_result=$(echo "$query_result" | tr -d '[]' | tr ',' ' ')
|
||||
|
||||
# Sort the string.
|
||||
sorted_result_str=$(echo "$converted_result" | tr ' ' '\n' | xargs -n2 | sort | tr '\n' ' ' | sed 's/ $//')
|
||||
|
||||
# Define all possible expected outcomes, sorted
|
||||
declare -A expected_outcomes
|
||||
expected_outcomes["7 2 3 4"]=1
|
||||
expected_outcomes["7 2 7 6"]=1
|
||||
expected_outcomes["3 4 7 6"]=1
|
||||
expected_outcomes["3 4 7 2"]=1
|
||||
expected_outcomes["7 6 7 2"]=1
|
||||
expected_outcomes["7 6 3 4"]=1
|
||||
|
||||
if [[ -n "${expected_outcomes[$sorted_result_str]}" ]]; then
|
||||
echo "Nested Array Test: Passed"
|
||||
((passed_tests++))
|
||||
else
|
||||
echo "Nested Array Test: Failed"
|
||||
echo "Output: $query_result"
|
||||
echo "Processed Output: ${sorted_result_str}"
|
||||
fi
|
||||
((total_tests++))
|
||||
}
|
||||
|
||||
|
||||
# Test Function for K > array.size
|
||||
run_higher_k_test() {
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1,2,3], 5)")
|
||||
mapfile -t sorted_result < <(echo "$query_result" | tr -d '[]' | tr ',' '\n' | sort -n)
|
||||
sorted_original=("1" "2" "3")
|
||||
|
||||
are_arrays_equal=true
|
||||
for i in "${!sorted_result[@]}"; do
|
||||
if [[ "${sorted_result[$i]}" != "${sorted_original[$i]}" ]]; then
|
||||
are_arrays_equal=false
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if $are_arrays_equal; then
|
||||
echo "Higher Sample Number Test: Passed"
|
||||
((passed_tests++))
|
||||
else
|
||||
echo "Higher Sample Number Test: Failed"
|
||||
echo "Output: $query_result"
|
||||
fi
|
||||
((total_tests++))
|
||||
}
|
||||
|
||||
# Test Function for Integer Arrays with samples = 0
|
||||
run_integer_with_samples_0_test() {
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1,2,3], 0)")
|
||||
mapfile -t sorted_result < <(echo "$query_result" | tr -d '[]' | tr ',' '\n' | sort -n)
|
||||
|
||||
# An empty array should produce an empty string after transformations
|
||||
declare -A expected_outcomes
|
||||
expected_outcomes["EMPTY_ARRAY"]=1
|
||||
|
||||
# Prepare the result string for comparison
|
||||
sorted_result_str=$(echo "${sorted_result[*]}" | tr ' ' '\n' | sort -n | tr '\n' ' ' | sed 's/ $//')
|
||||
|
||||
# Use "EMPTY_ARRAY" as a placeholder for an empty array
|
||||
[[ -z "$sorted_result_str" ]] && sorted_result_str="EMPTY_ARRAY"
|
||||
|
||||
# Compare
|
||||
if [[ -n "${expected_outcomes[$sorted_result_str]}" ]]; then
|
||||
echo "Integer Test with K=0: Passed"
|
||||
((passed_tests++))
|
||||
else
|
||||
echo "Integer Test with K=0: Failed"
|
||||
echo "Output: $query_result"
|
||||
fi
|
||||
((total_tests++))
|
||||
}
|
||||
|
||||
# Test Function for Empty Array with K > 0
|
||||
run_empty_array_with_k_test() {
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([], 5)")
|
||||
|
||||
if [[ "$query_result" == "[]" ]]; then
|
||||
echo "Empty Array with K > 0 Test: Passed"
|
||||
((passed_tests++))
|
||||
else {
|
||||
echo "Empty Array with K > 0 Test: Failed"
|
||||
echo "Output: $query_result"
|
||||
}
|
||||
fi
|
||||
((total_tests++))
|
||||
}
|
||||
|
||||
# Test Function for Non-Unsigned-Integer K
|
||||
run_non_unsigned_integer_k_test() {
|
||||
# Test with negative integer
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1, 2, 3], -5)" 2>&1)
|
||||
if [[ "$query_result" == *"ILLEGAL_TYPE_OF_ARGUMENT"* ]]; then
|
||||
echo "Non-Unsigned-Integer K Test (Negative Integer): Passed"
|
||||
((passed_tests++))
|
||||
else {
|
||||
echo "Non-Unsigned-Integer K Test (Negative Integer): Failed"
|
||||
echo "Output: $query_result"
|
||||
}
|
||||
fi
|
||||
((total_tests++))
|
||||
|
||||
# Test with string
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1, 2, 3], 'a')" 2>&1)
|
||||
if [[ "$query_result" == *"ILLEGAL_TYPE_OF_ARGUMENT"* ]]; then
|
||||
echo "Non-Unsigned-Integer K Test (String): Passed"
|
||||
((passed_tests++))
|
||||
else {
|
||||
echo "Non-Unsigned-Integer K Test (String): Failed"
|
||||
echo "Output: $query_result"
|
||||
}
|
||||
fi
|
||||
((total_tests++))
|
||||
|
||||
# Test with floating-point number
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample([1, 2, 3], 1.5)" 2>&1)
|
||||
if [[ "$query_result" == *"ILLEGAL_TYPE_OF_ARGUMENT"* ]]; then
|
||||
echo "Non-Unsigned-Integer K Test (Floating-Point): Passed"
|
||||
((passed_tests++))
|
||||
else {
|
||||
echo "Non-Unsigned-Integer K Test (Floating-Point): Failed"
|
||||
echo "Output: $query_result"
|
||||
}
|
||||
fi
|
||||
((total_tests++))
|
||||
}
|
||||
|
||||
# Function to run a multi-row test with scalar 'k'
|
||||
run_multi_row_scalar_k_test() {
|
||||
# Create a table. Use a random database name as tests potentially run in parallel.
|
||||
db=`tr -dc A-Za-z0-9 </dev/urandom | head -c 13`
|
||||
clickhouse-client -q "DROP DATABASE IF EXISTS ${db}"
|
||||
clickhouse-client -q "CREATE DATABASE ${db}"
|
||||
clickhouse-client -q "CREATE TABLE ${db}.array_test (arr Array(Int32)) ENGINE = Memory"
|
||||
|
||||
# Insert multi-row data into the table
|
||||
clickhouse-client -q "INSERT INTO ${db}.array_test VALUES ([1, 2, 3]), ([4, 5, 6]), ([7, 8, 9])"
|
||||
|
||||
# Query using arrayRandomSample function and store the result, k is scalar here (for example, 2)
|
||||
query_result=$(clickhouse-client -q "SELECT arrayRandomSample(arr, 2) FROM ${db}.array_test")
|
||||
|
||||
# Drop the table
|
||||
clickhouse-client -q "DROP DATABASE ${db}"
|
||||
|
||||
# Validate the output here
|
||||
is_test_passed=1 # flag to indicate if the test passed; 1 means passed, 0 means failed
|
||||
|
||||
# Iterate over each line (each array) in the output
|
||||
echo "$query_result" | while read -r line; do
|
||||
# Remove brackets from the array string
|
||||
line=$(echo "$line" | tr -d '[]')
|
||||
|
||||
# Convert the string to an array
|
||||
IFS=", " read -ra nums <<< "$line"
|
||||
|
||||
# Check if the array contains exactly 2 unique elements
|
||||
if [[ ${#nums[@]} -ne 2 ]] || [[ ${nums[0]} -eq ${nums[1]} ]]; then
|
||||
# shellcheck disable=SC2030
|
||||
is_test_passed=0
|
||||
fi
|
||||
done
|
||||
|
||||
# Print test result
|
||||
# shellcheck disable=SC2031
|
||||
if [[ $is_test_passed -eq 1 ]]; then
|
||||
echo "Multi-row Test with scalar k: Passed"
|
||||
((passed_tests++))
|
||||
else
|
||||
echo "Multi-row Test with scalar k: Failed"
|
||||
echo "Output: $query_result"
|
||||
fi
|
||||
|
||||
((total_tests++))
|
||||
}
|
||||
|
||||
|
||||
|
||||
# Run test multiple times
|
||||
for i in {1..5}; do
|
||||
echo "Running iteration: $i"
|
||||
run_integer_test
|
||||
run_string_test
|
||||
run_nested_array_test
|
||||
run_higher_k_test
|
||||
run_multi_row_scalar_k_test
|
||||
done
|
||||
|
||||
run_integer_with_samples_0_test
|
||||
run_empty_array_with_k_test
|
||||
run_non_unsigned_integer_k_test
|
||||
|
||||
# Print overall test results
|
||||
echo "Total tests: $total_tests"
|
||||
echo "Passed tests: $passed_tests"
|
@ -1061,6 +1061,7 @@ arrayPopFront
|
||||
arrayProduct
|
||||
arrayPushBack
|
||||
arrayPushFront
|
||||
arrayRandomSample
|
||||
arrayReduce
|
||||
arrayReduceInRanges
|
||||
arrayResize
|
||||
|
Loading…
Reference in New Issue
Block a user