From d29f0d4c9668e134d1569f16089b77ea1a62e070 Mon Sep 17 00:00:00 2001 From: Peter Nguyen Date: Tue, 27 Aug 2024 21:29:34 -0600 Subject: [PATCH] Pull out code into insertElement() function and implement arrayUnion logic --- src/Functions/array/arrayIntersect.cpp | 169 ++++++++++++++++--------- 1 file changed, 111 insertions(+), 58 deletions(-) diff --git a/src/Functions/array/arrayIntersect.cpp b/src/Functions/array/arrayIntersect.cpp index 8affe1ac11c..53814e8b8da 100644 --- a/src/Functions/array/arrayIntersect.cpp +++ b/src/Functions/array/arrayIntersect.cpp @@ -108,6 +108,9 @@ private: template static ColumnPtr execute(const UnpackedArrays & arrays, MutableColumnPtr result_data); + template + static void insertElement(typename Map::LookupResult & pair, size_t & result_offset, ColumnType & result_data, NullMap & null_map, const bool & use_null_map); + struct NumberExecutor { const UnpackedArrays & arrays; @@ -158,6 +161,12 @@ DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & ar if (typeid_cast(nested_type.get())) has_nothing = true; + // { + // if (std::is_same_v) { + // has_nothing = true; + // break; + // } + // } else nested_types.push_back(nested_type); } @@ -169,6 +178,11 @@ DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & ar if (has_nothing) result_type = std::make_shared(); + // // If found any DataTypeNothing in IntersectMode or all DattaTypeNothing in UnionMode + // if (has_nothing || nested_types.empty()) + // result_type = std::make_shared(); + // else + // result_type = getMostSubtype(nested_types, true); return std::make_shared(result_type); } @@ -529,6 +543,7 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, M map.clear(); bool all_has_nullable = all_nullable; + bool has_a_null = false; bool current_has_nullable = false; for (size_t arg_num = 0; arg_num < args; ++arg_num) @@ -564,7 +579,7 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, M } /// Here we count the number of element appearances, but no more than once per array. - if (*value == arg_num) + if (*value <= arg_num) ++(*value); } } @@ -579,77 +594,93 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, M } if (!current_has_nullable) all_has_nullable = false; + else + has_a_null = true; } // We have NULL in output only once if it should be there bool null_added = false; - const auto & arg = arrays.args[0]; - size_t off; - // const array has only one row - if (arg.is_const) - off = (*arg.offsets)[0]; - else - off = (*arg.offsets)[row]; + bool use_null_map; - for (auto i : collections::range(prev_off[0], off)) + if (std::is_same_v) { - all_has_nullable = all_nullable; - typename Map::LookupResult pair = nullptr; - - if (arg.null_map && (*arg.null_map)[i]) + use_null_map = has_a_null; + for (auto & p : map) { - current_has_nullable = true; - if (all_has_nullable && !null_added) + typename Map::LookupResult pair = map.find(p.getKey()); + if (pair && pair->getMapped() >= 1) { - ++result_offset; - result_data.insertDefault(); - null_map.push_back(1); - null_added = true; + insertElement(pair, result_offset, result_data, null_map, use_null_map); } - if (null_added) - continue; } - else if constexpr (is_numeric_column) + if (has_a_null && !null_added) { - pair = map.find(columns[0]->getElement(i)); - } - else if constexpr (std::is_same_v || std::is_same_v) - pair = map.find(columns[0]->getDataAt(i)); - else - { - const char * data = nullptr; - pair = map.find(columns[0]->serializeValueIntoArena(i, arena, data)); - } - prev_off[0] = off; - if (arg.is_const) - prev_off[0] = 0; - - if (!current_has_nullable) - all_has_nullable = false; - - if (pair && pair->getMapped() == args) - { - // We increase pair->getMapped() here to not skip duplicate values from the first array. - ++pair->getMapped(); ++result_offset; - if constexpr (is_numeric_column) - { - result_data.insertValue(pair->getKey()); - } - else if constexpr (std::is_same_v || std::is_same_v) - { - result_data.insertData(pair->getKey().data, pair->getKey().size); - } - else - { - std::ignore = result_data.deserializeAndInsertFromArena(pair->getKey().data); - } - if (all_nullable) - null_map.push_back(0); + result_data.insertDefault(); + null_map.push_back(1); + null_added = true; } } - result_offsets.getElement(row) = result_offset; + else if (std::is_same_v) + { + use_null_map = all_nullable; + const auto & arg = arrays.args[0]; + size_t off; + // const array has only one row + if (arg.is_const) + off = (*arg.offsets)[0]; + else + off = (*arg.offsets)[row]; + for (auto i : collections::range(prev_off[0], off)) + { + all_has_nullable = all_nullable; + typename Map::LookupResult pair = nullptr; + + if (arg.null_map && (*arg.null_map)[i]) + { + current_has_nullable = true; + if (all_has_nullable && !null_added) + { + ++result_offset; + result_data.insertDefault(); + null_map.push_back(1); + null_added = true; + } + if (null_added) + continue; + } + else if constexpr (is_numeric_column) + { + pair = map.find(columns[0]->getElement(i)); + } + else if constexpr (std::is_same_v || std::is_same_v) + pair = map.find(columns[0]->getDataAt(i)); + else + { + const char * data = nullptr; + pair = map.find(columns[0]->serializeValueIntoArena(i, arena, data)); + } + prev_off[0] = off; + if (arg.is_const) + prev_off[0] = 0; + + if (!current_has_nullable) + all_has_nullable = false; + + // Add the value if all arrays have the value for intersect + // or if there was at least one occurence in all of the arrays + if (pair && pair->getMapped() == args) + { + insertElement(pair, result_offset, result_data, null_map, use_null_map); + } + } + } + else + { + + } + result_offsets.getElement(row) = result_offset; } ColumnPtr result_column = std::move(result_data_ptr); if (all_nullable) @@ -658,14 +689,36 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, M } +template +template +void FunctionArrayIntersect::insertElement(typename Map::LookupResult & pair, size_t & result_offset, ColumnType & result_data, NullMap & null_map, const bool & use_null_map) +{ + pair->getMapped() = -1; + ++result_offset; + if constexpr (is_numeric_column) + { + result_data.insertValue(pair->getKey()); + } + else if constexpr (std::is_same_v || std::is_same_v) + { + result_data.insertData(pair->getKey().data, pair->getKey().size); + } + else + { + std::ignore = result_data.deserializeAndInsertFromArena(pair->getKey().data); + } + if (use_null_map) + null_map.push_back(0); +} + using ArrayIntersect = FunctionArrayIntersect; using ArrayUnion = FunctionArrayIntersect; REGISTER_FUNCTION(ArrayIntersect) { - // factory.registerFunction(); factory.registerFunction(); + factory.registerFunction(); } }