diff --git a/docs/en/sql-reference/functions/array-functions.md b/docs/en/sql-reference/functions/array-functions.md index ad971ae7554..7b2b93fba3f 100644 --- a/docs/en/sql-reference/functions/array-functions.md +++ b/docs/en/sql-reference/functions/array-functions.md @@ -1717,6 +1717,24 @@ Result: [[1,1,2,3],[1,2,3,4]] ``` +## arrayUnion(arr) + +Takes multiple arrays, returns an array that contains all elements that are present in any of the source arrays. + +Example: +```sql +SELECT + arrayUnion([-2, 1], [10, 1], [-2], []) as num_example, + arrayUnion(['hi'], [], ['hello', 'hi']) as str_example, + arrayUnion([1, 3, NULL], [2, 3, NULL]) as null_example +``` + +```text +┌─num_example─┬─str_example────┬─null_example─┐ +│ [10,-2,1] │ ['hello','hi'] │ [3,2,1,NULL] │ +└─────────────┴────────────────┴──────────────┘ +``` + ## arrayIntersect(arr) Takes multiple arrays, returns an array with elements that are present in all source arrays. diff --git a/src/Functions/array/arrayIntersect.cpp b/src/Functions/array/arrayIntersect.cpp index 209441eb301..316cc869ca1 100644 --- a/src/Functions/array/arrayIntersect.cpp +++ b/src/Functions/array/arrayIntersect.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -12,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -35,10 +37,21 @@ namespace ErrorCodes extern const int ILLEGAL_TYPE_OF_ARGUMENT; } +struct ArrayModeIntersect +{ + static constexpr auto name = "arrayIntersect"; +}; + +struct ArrayModeUnion +{ + static constexpr auto name = "arrayUnion"; +}; + +template class FunctionArrayIntersect : public IFunction { public: - static constexpr auto name = "arrayIntersect"; + static constexpr auto name = Mode::name; static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } explicit FunctionArrayIntersect(ContextPtr context_) : context(context_) {} @@ -97,6 +110,9 @@ private: template static ColumnPtr execute(const UnpackedArrays & arrays, MutableColumnPtr result_data); + template + static void insertElement(typename Map::LookupResult & pair, size_t & result_offset, ColumnType & result_data, NullMap & null_map, const bool & use_null_map); + struct NumberExecutor { const UnpackedArrays & arrays; @@ -124,13 +140,15 @@ private: }; }; - -DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & arguments) const +template +DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & arguments) const { DataTypes nested_types; nested_types.reserve(arguments.size()); bool has_nothing = false; + DataTypePtr has_decimal_type = nullptr; + DataTypePtr has_non_decimal_type = nullptr; if (arguments.empty()) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires at least one argument.", getName()); @@ -146,23 +164,49 @@ DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & argument const auto & nested_type = array_type->getNestedType(); if (typeid_cast(nested_type.get())) - has_nothing = true; + { + if constexpr (std::is_same_v) + { + has_nothing = true; + break; + } + } else + { nested_types.push_back(nested_type); + + /// Throw exception if have a decimal and another type (e.g int/date type) + /// This is the same behavior as the arrayIntersect and notEquals functions + /// This case is not covered by getLeastSupertype() and results in crashing the program if left out + if constexpr (std::is_same_v) + { + if (WhichDataType(nested_type).isDecimal()) + has_decimal_type = nested_type; + else + has_non_decimal_type = nested_type; + + if (has_non_decimal_type && has_decimal_type) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal types of arguments for function {}: {} and {}.", + getName(), has_non_decimal_type->getName(), has_decimal_type); + } + } } DataTypePtr result_type; - if (!nested_types.empty()) - result_type = getMostSubtype(nested_types, true); - - if (has_nothing) + // If any DataTypeNothing in ArrayModeIntersect or all arrays in ArrayModeUnion are DataTypeNothing + if (has_nothing || nested_types.empty()) result_type = std::make_shared(); + else if constexpr (std::is_same_v) + result_type = getMostSubtype(nested_types, true); + else + result_type = getLeastSupertype(nested_types); return std::make_shared(result_type); } -ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const +template +ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const { if (const auto * column_nullable = checkAndGetColumn(column.get())) { @@ -208,7 +252,8 @@ ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, c return column; } -FunctionArrayIntersect::CastArgumentsResult FunctionArrayIntersect::castColumns( +template +FunctionArrayIntersect::CastArgumentsResult FunctionArrayIntersect::castColumns( const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, const DataTypePtr & return_type_with_nulls) { size_t num_args = arguments.size(); @@ -294,7 +339,8 @@ static ColumnPtr callFunctionNotEquals(ColumnWithTypeAndName first, ColumnWithTy return eq_func->execute(args, eq_func->getResultType(), args.front().column->size()); } -FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays( +template +FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays( const ColumnsWithTypeAndName & columns, ColumnsWithTypeAndName & initial_columns) const { UnpackedArrays arrays; @@ -384,7 +430,8 @@ FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays( return arrays; } -ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const +template +ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const { const auto * return_type_array = checkAndGetDataType(result_type.get()); @@ -402,7 +449,12 @@ ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arg for (size_t i = 0; i < num_args; ++i) data_types.push_back(arguments[i].type); - auto return_type_with_nulls = getMostSubtype(data_types, true, true); + DataTypePtr return_type_with_nulls; + if constexpr (std::is_same_v) + return_type_with_nulls = getMostSubtype(data_types, true, true); + else + return_type_with_nulls = getLeastSupertype(data_types); + auto casted_columns = castColumns(arguments, result_type, return_type_with_nulls); UnpackedArrays arrays = prepareArrays(casted_columns.casted, casted_columns.initial); @@ -450,8 +502,9 @@ ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arg return result_column; } +template template -void FunctionArrayIntersect::NumberExecutor::operator()(TypeList) +void FunctionArrayIntersect::NumberExecutor::operator()(TypeList) { using Container = ClearableHashMapWithStackMemory, INITIAL_SIZE_DEGREE>; @@ -460,8 +513,9 @@ void FunctionArrayIntersect::NumberExecutor::operator()(TypeList) result = execute, true>(arrays, ColumnVector::create()); } +template template -void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList) +void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList) { using Container = ClearableHashMapWithStackMemory, INITIAL_SIZE_DEGREE>; @@ -471,13 +525,15 @@ void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList) result = execute, true>(arrays, ColumnDecimal::create(0, decimal->getScale())); } +template template -ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr) +ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr) { auto args = arrays.args.size(); auto rows = arrays.base_rows; bool all_nullable = true; + bool has_nullable = false; std::vector columns; columns.reserve(args); @@ -493,6 +549,8 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable if (!arg.null_map) all_nullable = false; + else + has_nullable = true; } auto & result_data = static_cast(*result_data_ptr); @@ -511,6 +569,7 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable map.clear(); bool all_has_nullable = all_nullable; + bool has_a_null = false; bool current_has_nullable = false; for (size_t arg_num = 0; arg_num < args; ++arg_num) @@ -546,7 +605,7 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable } /// Here we count the number of element appearances, but no more than once per array. - if (*value == arg_num) + if (*value <= arg_num) ++(*value); } } @@ -561,77 +620,90 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable } if (!current_has_nullable) all_has_nullable = false; + else + has_a_null = true; } // We have NULL in output only once if it should be there bool null_added = false; - const auto & arg = arrays.args[0]; - size_t off; - // const array has only one row - if (arg.is_const) - off = (*arg.offsets)[0]; - else - off = (*arg.offsets)[row]; + bool use_null_map; - for (auto i : collections::range(prev_off[0], off)) + if constexpr (std::is_same_v) { - all_has_nullable = all_nullable; - typename Map::LookupResult pair = nullptr; - - if (arg.null_map && (*arg.null_map)[i]) + use_null_map = has_nullable; + for (auto & p : map) { - current_has_nullable = true; - if (all_has_nullable && !null_added) + typename Map::LookupResult pair = map.find(p.getKey()); + if (pair && pair->getMapped() >= 1) { - ++result_offset; - result_data.insertDefault(); - null_map.push_back(1); - null_added = true; + insertElement(pair, result_offset, result_data, null_map, use_null_map); } - if (null_added) - continue; } - else if constexpr (is_numeric_column) + if (has_a_null && !null_added) { - pair = map.find(columns[0]->getElement(i)); - } - else if constexpr (std::is_same_v || std::is_same_v) - pair = map.find(columns[0]->getDataAt(i)); - else - { - const char * data = nullptr; - pair = map.find(columns[0]->serializeValueIntoArena(i, arena, data)); - } - prev_off[0] = off; - if (arg.is_const) - prev_off[0] = 0; - - if (!current_has_nullable) - all_has_nullable = false; - - if (pair && pair->getMapped() == args) - { - // We increase pair->getMapped() here to not skip duplicate values from the first array. - ++pair->getMapped(); ++result_offset; - if constexpr (is_numeric_column) - { - result_data.insertValue(pair->getKey()); - } - else if constexpr (std::is_same_v || std::is_same_v) - { - result_data.insertData(pair->getKey().data, pair->getKey().size); - } - else - { - std::ignore = result_data.deserializeAndInsertFromArena(pair->getKey().data); - } - if (all_nullable) - null_map.push_back(0); + result_data.insertDefault(); + null_map.push_back(1); + null_added = true; } } - result_offsets.getElement(row) = result_offset; + else if constexpr (std::is_same_v) + { + use_null_map = all_nullable; + const auto & arg = arrays.args[0]; + size_t off; + // const array has only one row + if (arg.is_const) + off = (*arg.offsets)[0]; + else + off = (*arg.offsets)[row]; + for (auto i : collections::range(prev_off[0], off)) + { + all_has_nullable = all_nullable; + typename Map::LookupResult pair = nullptr; + + if (arg.null_map && (*arg.null_map)[i]) + { + current_has_nullable = true; + if (all_has_nullable && !null_added) + { + ++result_offset; + result_data.insertDefault(); + null_map.push_back(1); + null_added = true; + } + if (null_added) + continue; + } + else if constexpr (is_numeric_column) + { + pair = map.find(columns[0]->getElement(i)); + } + else if constexpr (std::is_same_v || std::is_same_v) + pair = map.find(columns[0]->getDataAt(i)); + else + { + const char * data = nullptr; + pair = map.find(columns[0]->serializeValueIntoArena(i, arena, data)); + } + prev_off[0] = off; + if (arg.is_const) + prev_off[0] = 0; + + if (!current_has_nullable) + all_has_nullable = false; + + // Add the value if all arrays have the value for intersect + // or if there was at least one occurrence in all of the arrays for union + if (pair && pair->getMapped() == args) + { + insertElement(pair, result_offset, result_data, null_map, use_null_map); + } + } + } + + result_offsets.getElement(row) = result_offset; } ColumnPtr result_column = std::move(result_data_ptr); if (all_nullable) @@ -640,10 +712,36 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable } +template +template +void FunctionArrayIntersect::insertElement(typename Map::LookupResult & pair, size_t & result_offset, ColumnType & result_data, NullMap & null_map, const bool & use_null_map) +{ + pair->getMapped() = -1; + ++result_offset; + if constexpr (is_numeric_column) + { + result_data.insertValue(pair->getKey()); + } + else if constexpr (std::is_same_v || std::is_same_v) + { + result_data.insertData(pair->getKey().data, pair->getKey().size); + } + else + { + std::ignore = result_data.deserializeAndInsertFromArena(pair->getKey().data); + } + if (use_null_map) + null_map.push_back(0); +} + + +using ArrayIntersect = FunctionArrayIntersect; +using ArrayUnion = FunctionArrayIntersect; REGISTER_FUNCTION(ArrayIntersect) { - factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); } } diff --git a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference index 533389a40f6..ec5bdbb54b5 100644 --- a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference +++ b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference @@ -141,6 +141,7 @@ arraySort arraySplit arrayStringConcat arraySum +arrayUnion arrayUniq arrayWithConstant asinh diff --git a/tests/queries/0_stateless/03224_arrayUnion.reference b/tests/queries/0_stateless/03224_arrayUnion.reference new file mode 100644 index 00000000000..b900b6cdb0a --- /dev/null +++ b/tests/queries/0_stateless/03224_arrayUnion.reference @@ -0,0 +1,43 @@ +[1,2] +[1,2] +[1,2] +[1,2,3] +------- +[] +[1] +[1,2] +[1,2,3] +------- +[] +[1] +[1,2] +[1,2,3] +------- +[1,2] +[1,2] +[1,2] +[1,2,3] +------- +[1,2,3,4] +[1,2,3,4] +[1,2,3,4] +[1,2,3,4] +------- +[] +[] +[] +[] +------- +[-100,156] +------- +[-257,-100,1] +------- +['hello','hi'] +------- +[1,2,3,NULL] +------- +[1,2,3,NULL] +------- +[1,2,3,4,5,10,20] +------- +[1,2,3] diff --git a/tests/queries/0_stateless/03224_arrayUnion.sql b/tests/queries/0_stateless/03224_arrayUnion.sql new file mode 100644 index 00000000000..dedbacad906 --- /dev/null +++ b/tests/queries/0_stateless/03224_arrayUnion.sql @@ -0,0 +1,38 @@ +drop table if exists array_union; + +create table array_union (date Date, arr Array(UInt8)) engine=MergeTree partition by date order by date; + +insert into array_union values ('2019-01-01', [1,2,3]); +insert into array_union values ('2019-01-01', [1,2]); +insert into array_union values ('2019-01-01', [1]); +insert into array_union values ('2019-01-01', []); + + +select arraySort(arrayUnion(arr, [1,2])) from array_union order by arr; +select '-------'; +select arraySort(arrayUnion(arr, [])) from array_union order by arr; +select '-------'; +select arraySort(arrayUnion([], arr)) from array_union order by arr; +select '-------'; +select arraySort(arrayUnion([1,2], arr)) from array_union order by arr; +select '-------'; +select arraySort(arrayUnion([1,2], [1,2,3,4])) from array_union order by arr; +select '-------'; +select arraySort(arrayUnion([], [])) from array_union order by arr; + +drop table if exists array_union; + +select '-------'; +select arraySort(arrayUnion([-100], [156])); +select '-------'; +select arraySort(arrayUnion([1], [-257, -100])); +select '-------'; +select arraySort(arrayUnion(['hi'], ['hello', 'hi'], [])); +select '-------'; +SELECT arraySort(arrayUnion([1, 2, NULL], [1, 3, NULL], [2, 3, NULL])); +select '-------'; +SELECT arraySort(arrayUnion([NULL, NULL, NULL, 1], [1, NULL, NULL], [1, 2, 3, NULL])); +select '-------'; +SELECT arraySort(arrayUnion([1, 1, 1, 2, 3], [2, 2, 4], [5, 10, 20])); +select '-------'; +SELECT arraySort(arrayUnion([1, 2], [1, 3], [])), diff --git a/utils/check-style/aspell-ignore/en/aspell-dict.txt b/utils/check-style/aspell-ignore/en/aspell-dict.txt index 3467f21c812..85d75930850 100644 --- a/utils/check-style/aspell-ignore/en/aspell-dict.txt +++ b/utils/check-style/aspell-ignore/en/aspell-dict.txt @@ -1,4 +1,4 @@ -personal_ws-1.1 en 2983 +personal_ws-1.1 en 2984 AArch ACLs ALTERs @@ -1213,6 +1213,7 @@ arraySort arraySplit arrayStringConcat arraySum +arrayUnion arrayUniq arrayWithConstant arrayZip