Merge pull request #68989 from petern48/add_arrayUnion_func

Add support for `arrayUnion()` function
This commit is contained in:
vdimir 2024-09-20 14:40:16 +00:00 committed by GitHub
commit 8ba6acb64c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 274 additions and 75 deletions

View File

@ -1717,6 +1717,24 @@ Result:
[[1,1,2,3],[1,2,3,4]]
```
## arrayUnion(arr)
Takes multiple arrays, returns an array that contains all elements that are present in any of the source arrays.
Example:
```sql
SELECT
arrayUnion([-2, 1], [10, 1], [-2], []) as num_example,
arrayUnion(['hi'], [], ['hello', 'hi']) as str_example,
arrayUnion([1, 3, NULL], [2, 3, NULL]) as null_example
```
```text
┌─num_example─┬─str_example────┬─null_example─┐
│ [10,-2,1] │ ['hello','hi'] │ [3,2,1,NULL] │
└─────────────┴────────────────┴──────────────┘
```
## arrayIntersect(arr)
Takes multiple arrays, returns an array with elements that are present in all source arrays.

View File

@ -1,3 +1,4 @@
#include <type_traits>
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
@ -12,6 +13,7 @@
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/getMostSubtype.h>
#include <DataTypes/getLeastSupertype.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnFixedString.h>
@ -35,10 +37,21 @@ namespace ErrorCodes
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
struct ArrayModeIntersect
{
static constexpr auto name = "arrayIntersect";
};
struct ArrayModeUnion
{
static constexpr auto name = "arrayUnion";
};
template <typename Mode>
class FunctionArrayIntersect : public IFunction
{
public:
static constexpr auto name = "arrayIntersect";
static constexpr auto name = Mode::name;
static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionArrayIntersect>(context); }
explicit FunctionArrayIntersect(ContextPtr context_) : context(context_) {}
@ -97,6 +110,9 @@ private:
template <typename Map, typename ColumnType, bool is_numeric_column>
static ColumnPtr execute(const UnpackedArrays & arrays, MutableColumnPtr result_data);
template <typename Map, typename ColumnType, bool is_numeric_column>
static void insertElement(typename Map::LookupResult & pair, size_t & result_offset, ColumnType & result_data, NullMap & null_map, const bool & use_null_map);
struct NumberExecutor
{
const UnpackedArrays & arrays;
@ -124,13 +140,15 @@ private:
};
};
DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & arguments) const
template <typename Mode>
DataTypePtr FunctionArrayIntersect<Mode>::getReturnTypeImpl(const DataTypes & arguments) const
{
DataTypes nested_types;
nested_types.reserve(arguments.size());
bool has_nothing = false;
DataTypePtr has_decimal_type = nullptr;
DataTypePtr has_non_decimal_type = nullptr;
if (arguments.empty())
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least one argument.", getName());
@ -146,23 +164,49 @@ DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & argument
const auto & nested_type = array_type->getNestedType();
if (typeid_cast<const DataTypeNothing *>(nested_type.get()))
has_nothing = true;
{
if constexpr (std::is_same_v<Mode, ArrayModeIntersect>)
{
has_nothing = true;
break;
}
}
else
{
nested_types.push_back(nested_type);
/// Throw exception if have a decimal and another type (e.g int/date type)
/// This is the same behavior as the arrayIntersect and notEquals functions
/// This case is not covered by getLeastSupertype() and results in crashing the program if left out
if constexpr (std::is_same_v<Mode, ArrayModeUnion>)
{
if (WhichDataType(nested_type).isDecimal())
has_decimal_type = nested_type;
else
has_non_decimal_type = nested_type;
if (has_non_decimal_type && has_decimal_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal types of arguments for function {}: {} and {}.",
getName(), has_non_decimal_type->getName(), has_decimal_type);
}
}
}
DataTypePtr result_type;
if (!nested_types.empty())
result_type = getMostSubtype(nested_types, true);
if (has_nothing)
// If any DataTypeNothing in ArrayModeIntersect or all arrays in ArrayModeUnion are DataTypeNothing
if (has_nothing || nested_types.empty())
result_type = std::make_shared<DataTypeNothing>();
else if constexpr (std::is_same_v<Mode, ArrayModeIntersect>)
result_type = getMostSubtype(nested_types, true);
else
result_type = getLeastSupertype(nested_types);
return std::make_shared<DataTypeArray>(result_type);
}
ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const
template <typename Mode>
ColumnPtr FunctionArrayIntersect<Mode>::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const
{
if (const auto * column_nullable = checkAndGetColumn<ColumnNullable>(column.get()))
{
@ -208,7 +252,8 @@ ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, c
return column;
}
FunctionArrayIntersect::CastArgumentsResult FunctionArrayIntersect::castColumns(
template <typename Mode>
FunctionArrayIntersect<Mode>::CastArgumentsResult FunctionArrayIntersect<Mode>::castColumns(
const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, const DataTypePtr & return_type_with_nulls)
{
size_t num_args = arguments.size();
@ -294,7 +339,8 @@ static ColumnPtr callFunctionNotEquals(ColumnWithTypeAndName first, ColumnWithTy
return eq_func->execute(args, eq_func->getResultType(), args.front().column->size());
}
FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays(
template <typename Mode>
FunctionArrayIntersect<Mode>::UnpackedArrays FunctionArrayIntersect<Mode>::prepareArrays(
const ColumnsWithTypeAndName & columns, ColumnsWithTypeAndName & initial_columns) const
{
UnpackedArrays arrays;
@ -384,7 +430,8 @@ FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays(
return arrays;
}
ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
template <typename Mode>
ColumnPtr FunctionArrayIntersect<Mode>::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
const auto * return_type_array = checkAndGetDataType<DataTypeArray>(result_type.get());
@ -402,7 +449,12 @@ ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arg
for (size_t i = 0; i < num_args; ++i)
data_types.push_back(arguments[i].type);
auto return_type_with_nulls = getMostSubtype(data_types, true, true);
DataTypePtr return_type_with_nulls;
if constexpr (std::is_same_v<Mode, ArrayModeIntersect>)
return_type_with_nulls = getMostSubtype(data_types, true, true);
else
return_type_with_nulls = getLeastSupertype(data_types);
auto casted_columns = castColumns(arguments, result_type, return_type_with_nulls);
UnpackedArrays arrays = prepareArrays(casted_columns.casted, casted_columns.initial);
@ -450,8 +502,9 @@ ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arg
return result_column;
}
template <typename Mode>
template <class T>
void FunctionArrayIntersect::NumberExecutor::operator()(TypeList<T>)
void FunctionArrayIntersect<Mode>::NumberExecutor::operator()(TypeList<T>)
{
using Container = ClearableHashMapWithStackMemory<T, size_t, DefaultHash<T>,
INITIAL_SIZE_DEGREE>;
@ -460,8 +513,9 @@ void FunctionArrayIntersect::NumberExecutor::operator()(TypeList<T>)
result = execute<Container, ColumnVector<T>, true>(arrays, ColumnVector<T>::create());
}
template <typename Mode>
template <class T>
void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList<T>)
void FunctionArrayIntersect<Mode>::DecimalExecutor::operator()(TypeList<T>)
{
using Container = ClearableHashMapWithStackMemory<T, size_t, DefaultHash<T>,
INITIAL_SIZE_DEGREE>;
@ -471,13 +525,15 @@ void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList<T>)
result = execute<Container, ColumnDecimal<T>, true>(arrays, ColumnDecimal<T>::create(0, decimal->getScale()));
}
template <typename Mode>
template <typename Map, typename ColumnType, bool is_numeric_column>
ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr)
ColumnPtr FunctionArrayIntersect<Mode>::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr)
{
auto args = arrays.args.size();
auto rows = arrays.base_rows;
bool all_nullable = true;
bool has_nullable = false;
std::vector<const ColumnType *> columns;
columns.reserve(args);
@ -493,6 +549,8 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable
if (!arg.null_map)
all_nullable = false;
else
has_nullable = true;
}
auto & result_data = static_cast<ColumnType &>(*result_data_ptr);
@ -511,6 +569,7 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable
map.clear();
bool all_has_nullable = all_nullable;
bool has_a_null = false;
bool current_has_nullable = false;
for (size_t arg_num = 0; arg_num < args; ++arg_num)
@ -546,7 +605,7 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable
}
/// Here we count the number of element appearances, but no more than once per array.
if (*value == arg_num)
if (*value <= arg_num)
++(*value);
}
}
@ -561,77 +620,90 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable
}
if (!current_has_nullable)
all_has_nullable = false;
else
has_a_null = true;
}
// We have NULL in output only once if it should be there
bool null_added = false;
const auto & arg = arrays.args[0];
size_t off;
// const array has only one row
if (arg.is_const)
off = (*arg.offsets)[0];
else
off = (*arg.offsets)[row];
bool use_null_map;
for (auto i : collections::range(prev_off[0], off))
if constexpr (std::is_same_v<Mode, ArrayModeUnion>)
{
all_has_nullable = all_nullable;
typename Map::LookupResult pair = nullptr;
if (arg.null_map && (*arg.null_map)[i])
use_null_map = has_nullable;
for (auto & p : map)
{
current_has_nullable = true;
if (all_has_nullable && !null_added)
typename Map::LookupResult pair = map.find(p.getKey());
if (pair && pair->getMapped() >= 1)
{
++result_offset;
result_data.insertDefault();
null_map.push_back(1);
null_added = true;
insertElement<Map, ColumnType, is_numeric_column>(pair, result_offset, result_data, null_map, use_null_map);
}
if (null_added)
continue;
}
else if constexpr (is_numeric_column)
if (has_a_null && !null_added)
{
pair = map.find(columns[0]->getElement(i));
}
else if constexpr (std::is_same_v<ColumnType, ColumnString> || std::is_same_v<ColumnType, ColumnFixedString>)
pair = map.find(columns[0]->getDataAt(i));
else
{
const char * data = nullptr;
pair = map.find(columns[0]->serializeValueIntoArena(i, arena, data));
}
prev_off[0] = off;
if (arg.is_const)
prev_off[0] = 0;
if (!current_has_nullable)
all_has_nullable = false;
if (pair && pair->getMapped() == args)
{
// We increase pair->getMapped() here to not skip duplicate values from the first array.
++pair->getMapped();
++result_offset;
if constexpr (is_numeric_column)
{
result_data.insertValue(pair->getKey());
}
else if constexpr (std::is_same_v<ColumnType, ColumnString> || std::is_same_v<ColumnType, ColumnFixedString>)
{
result_data.insertData(pair->getKey().data, pair->getKey().size);
}
else
{
std::ignore = result_data.deserializeAndInsertFromArena(pair->getKey().data);
}
if (all_nullable)
null_map.push_back(0);
result_data.insertDefault();
null_map.push_back(1);
null_added = true;
}
}
result_offsets.getElement(row) = result_offset;
else if constexpr (std::is_same_v<Mode, ArrayModeIntersect>)
{
use_null_map = all_nullable;
const auto & arg = arrays.args[0];
size_t off;
// const array has only one row
if (arg.is_const)
off = (*arg.offsets)[0];
else
off = (*arg.offsets)[row];
for (auto i : collections::range(prev_off[0], off))
{
all_has_nullable = all_nullable;
typename Map::LookupResult pair = nullptr;
if (arg.null_map && (*arg.null_map)[i])
{
current_has_nullable = true;
if (all_has_nullable && !null_added)
{
++result_offset;
result_data.insertDefault();
null_map.push_back(1);
null_added = true;
}
if (null_added)
continue;
}
else if constexpr (is_numeric_column)
{
pair = map.find(columns[0]->getElement(i));
}
else if constexpr (std::is_same_v<ColumnType, ColumnString> || std::is_same_v<ColumnType, ColumnFixedString>)
pair = map.find(columns[0]->getDataAt(i));
else
{
const char * data = nullptr;
pair = map.find(columns[0]->serializeValueIntoArena(i, arena, data));
}
prev_off[0] = off;
if (arg.is_const)
prev_off[0] = 0;
if (!current_has_nullable)
all_has_nullable = false;
// Add the value if all arrays have the value for intersect
// or if there was at least one occurrence in all of the arrays for union
if (pair && pair->getMapped() == args)
{
insertElement<Map, ColumnType, is_numeric_column>(pair, result_offset, result_data, null_map, use_null_map);
}
}
}
result_offsets.getElement(row) = result_offset;
}
ColumnPtr result_column = std::move(result_data_ptr);
if (all_nullable)
@ -640,10 +712,36 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable
}
template <typename Mode>
template <typename Map, typename ColumnType, bool is_numeric_column>
void FunctionArrayIntersect<Mode>::insertElement(typename Map::LookupResult & pair, size_t & result_offset, ColumnType & result_data, NullMap & null_map, const bool & use_null_map)
{
pair->getMapped() = -1;
++result_offset;
if constexpr (is_numeric_column)
{
result_data.insertValue(pair->getKey());
}
else if constexpr (std::is_same_v<ColumnType, ColumnString> || std::is_same_v<ColumnType, ColumnFixedString>)
{
result_data.insertData(pair->getKey().data, pair->getKey().size);
}
else
{
std::ignore = result_data.deserializeAndInsertFromArena(pair->getKey().data);
}
if (use_null_map)
null_map.push_back(0);
}
using ArrayIntersect = FunctionArrayIntersect<ArrayModeIntersect>;
using ArrayUnion = FunctionArrayIntersect<ArrayModeUnion>;
REGISTER_FUNCTION(ArrayIntersect)
{
factory.registerFunction<FunctionArrayIntersect>();
factory.registerFunction<ArrayIntersect>();
factory.registerFunction<ArrayUnion>();
}
}

View File

@ -141,6 +141,7 @@ arraySort
arraySplit
arrayStringConcat
arraySum
arrayUnion
arrayUniq
arrayWithConstant
asinh

View File

@ -0,0 +1,43 @@
[1,2]
[1,2]
[1,2]
[1,2,3]
-------
[]
[1]
[1,2]
[1,2,3]
-------
[]
[1]
[1,2]
[1,2,3]
-------
[1,2]
[1,2]
[1,2]
[1,2,3]
-------
[1,2,3,4]
[1,2,3,4]
[1,2,3,4]
[1,2,3,4]
-------
[]
[]
[]
[]
-------
[-100,156]
-------
[-257,-100,1]
-------
['hello','hi']
-------
[1,2,3,NULL]
-------
[1,2,3,NULL]
-------
[1,2,3,4,5,10,20]
-------
[1,2,3]

View File

@ -0,0 +1,38 @@
drop table if exists array_union;
create table array_union (date Date, arr Array(UInt8)) engine=MergeTree partition by date order by date;
insert into array_union values ('2019-01-01', [1,2,3]);
insert into array_union values ('2019-01-01', [1,2]);
insert into array_union values ('2019-01-01', [1]);
insert into array_union values ('2019-01-01', []);
select arraySort(arrayUnion(arr, [1,2])) from array_union order by arr;
select '-------';
select arraySort(arrayUnion(arr, [])) from array_union order by arr;
select '-------';
select arraySort(arrayUnion([], arr)) from array_union order by arr;
select '-------';
select arraySort(arrayUnion([1,2], arr)) from array_union order by arr;
select '-------';
select arraySort(arrayUnion([1,2], [1,2,3,4])) from array_union order by arr;
select '-------';
select arraySort(arrayUnion([], [])) from array_union order by arr;
drop table if exists array_union;
select '-------';
select arraySort(arrayUnion([-100], [156]));
select '-------';
select arraySort(arrayUnion([1], [-257, -100]));
select '-------';
select arraySort(arrayUnion(['hi'], ['hello', 'hi'], []));
select '-------';
SELECT arraySort(arrayUnion([1, 2, NULL], [1, 3, NULL], [2, 3, NULL]));
select '-------';
SELECT arraySort(arrayUnion([NULL, NULL, NULL, 1], [1, NULL, NULL], [1, 2, 3, NULL]));
select '-------';
SELECT arraySort(arrayUnion([1, 1, 1, 2, 3], [2, 2, 4], [5, 10, 20]));
select '-------';
SELECT arraySort(arrayUnion([1, 2], [1, 3], [])),

View File

@ -1,4 +1,4 @@
personal_ws-1.1 en 2983
personal_ws-1.1 en 2984
AArch
ACLs
ALTERs
@ -1213,6 +1213,7 @@ arraySort
arraySplit
arrayStringConcat
arraySum
arrayUnion
arrayUniq
arrayWithConstant
arrayZip