Pull out code into insertElement() function and implement arrayUnion logic

This commit is contained in:
Peter Nguyen 2024-08-27 21:29:34 -06:00
parent 1dcfaa91c2
commit d29f0d4c96

View File

@ -108,6 +108,9 @@ private:
/// Builds the result column from the unpacked array arguments, appending the produced
/// elements to `result_data`. The definition branches on the class's `Mode`
/// (ArrayModeIntersect / ArrayModeUnion) to decide which elements are emitted.
template <typename Map, typename ColumnType, bool is_numeric_column>
static ColumnPtr execute(const UnpackedArrays & arrays, MutableColumnPtr result_data);
/// Appends the key of the map entry `pair` to `result_data` and increments `result_offset`.
/// The entry's mapped value is overwritten with -1 as part of the call.
/// When `use_null_map` is true, a 0 is also pushed to `null_map` for the appended element.
template <typename Map, typename ColumnType, bool is_numeric_column>
static void insertElement(typename Map::LookupResult & pair, size_t & result_offset, ColumnType & result_data, NullMap & null_map, const bool & use_null_map);
struct NumberExecutor
{
const UnpackedArrays & arrays;
@ -158,6 +161,12 @@ DataTypePtr FunctionArrayIntersect<Mode>::getReturnTypeImpl(const DataTypes & ar
if (typeid_cast<const DataTypeNothing *>(nested_type.get()))
has_nothing = true;
// {
// if (std::is_same_v<Mode, ArrayModeIntersect>) {
// has_nothing = true;
// break;
// }
// }
else
nested_types.push_back(nested_type);
}
@ -169,6 +178,11 @@ DataTypePtr FunctionArrayIntersect<Mode>::getReturnTypeImpl(const DataTypes & ar
if (has_nothing)
result_type = std::make_shared<DataTypeNothing>();
// // If any DataTypeNothing is found in IntersectMode, or all types are DataTypeNothing in UnionMode
// if (has_nothing || nested_types.empty())
// result_type = std::make_shared<DataTypeNothing>();
// else
// result_type = getMostSubtype(nested_types, true);
return std::make_shared<DataTypeArray>(result_type);
}
@ -529,6 +543,7 @@ ColumnPtr FunctionArrayIntersect<Mode>::execute(const UnpackedArrays & arrays, M
map.clear();
bool all_has_nullable = all_nullable;
bool has_a_null = false;
bool current_has_nullable = false;
for (size_t arg_num = 0; arg_num < args; ++arg_num)
@ -564,7 +579,7 @@ ColumnPtr FunctionArrayIntersect<Mode>::execute(const UnpackedArrays & arrays, M
}
/// Here we count the number of element appearances, but no more than once per array.
if (*value == arg_num)
if (*value <= arg_num)
++(*value);
}
}
@ -579,77 +594,93 @@ ColumnPtr FunctionArrayIntersect<Mode>::execute(const UnpackedArrays & arrays, M
}
if (!current_has_nullable)
all_has_nullable = false;
else
has_a_null = true;
}
// We have NULL in output only once if it should be there
bool null_added = false;
const auto & arg = arrays.args[0];
size_t off;
// const array has only one row
if (arg.is_const)
off = (*arg.offsets)[0];
else
off = (*arg.offsets)[row];
bool use_null_map;
for (auto i : collections::range(prev_off[0], off))
if (std::is_same_v<Mode, ArrayModeUnion>)
{
all_has_nullable = all_nullable;
typename Map::LookupResult pair = nullptr;
if (arg.null_map && (*arg.null_map)[i])
use_null_map = has_a_null;
for (auto & p : map)
{
current_has_nullable = true;
if (all_has_nullable && !null_added)
typename Map::LookupResult pair = map.find(p.getKey());
if (pair && pair->getMapped() >= 1)
{
++result_offset;
result_data.insertDefault();
null_map.push_back(1);
null_added = true;
insertElement<Map, ColumnType, is_numeric_column>(pair, result_offset, result_data, null_map, use_null_map);
}
if (null_added)
continue;
}
else if constexpr (is_numeric_column)
if (has_a_null && !null_added)
{
pair = map.find(columns[0]->getElement(i));
}
else if constexpr (std::is_same_v<ColumnType, ColumnString> || std::is_same_v<ColumnType, ColumnFixedString>)
pair = map.find(columns[0]->getDataAt(i));
else
{
const char * data = nullptr;
pair = map.find(columns[0]->serializeValueIntoArena(i, arena, data));
}
prev_off[0] = off;
if (arg.is_const)
prev_off[0] = 0;
if (!current_has_nullable)
all_has_nullable = false;
if (pair && pair->getMapped() == args)
{
// We increase pair->getMapped() here to not skip duplicate values from the first array.
++pair->getMapped();
++result_offset;
if constexpr (is_numeric_column)
{
result_data.insertValue(pair->getKey());
}
else if constexpr (std::is_same_v<ColumnType, ColumnString> || std::is_same_v<ColumnType, ColumnFixedString>)
{
result_data.insertData(pair->getKey().data, pair->getKey().size);
}
else
{
std::ignore = result_data.deserializeAndInsertFromArena(pair->getKey().data);
}
if (all_nullable)
null_map.push_back(0);
result_data.insertDefault();
null_map.push_back(1);
null_added = true;
}
}
result_offsets.getElement(row) = result_offset;
else if (std::is_same_v<Mode, ArrayModeIntersect>)
{
use_null_map = all_nullable;
const auto & arg = arrays.args[0];
size_t off;
// const array has only one row
if (arg.is_const)
off = (*arg.offsets)[0];
else
off = (*arg.offsets)[row];
for (auto i : collections::range(prev_off[0], off))
{
all_has_nullable = all_nullable;
typename Map::LookupResult pair = nullptr;
if (arg.null_map && (*arg.null_map)[i])
{
current_has_nullable = true;
if (all_has_nullable && !null_added)
{
++result_offset;
result_data.insertDefault();
null_map.push_back(1);
null_added = true;
}
if (null_added)
continue;
}
else if constexpr (is_numeric_column)
{
pair = map.find(columns[0]->getElement(i));
}
else if constexpr (std::is_same_v<ColumnType, ColumnString> || std::is_same_v<ColumnType, ColumnFixedString>)
pair = map.find(columns[0]->getDataAt(i));
else
{
const char * data = nullptr;
pair = map.find(columns[0]->serializeValueIntoArena(i, arena, data));
}
prev_off[0] = off;
if (arg.is_const)
prev_off[0] = 0;
if (!current_has_nullable)
all_has_nullable = false;
// Add the value if all arrays have the value for intersect
// or if there was at least one occurrence in all of the arrays
if (pair && pair->getMapped() == args)
{
insertElement<Map, ColumnType, is_numeric_column>(pair, result_offset, result_data, null_map, use_null_map);
}
}
}
else
{
}
result_offsets.getElement(row) = result_offset;
}
ColumnPtr result_column = std::move(result_data_ptr);
if (all_nullable)
@ -658,14 +689,36 @@ ColumnPtr FunctionArrayIntersect<Mode>::execute(const UnpackedArrays & arrays, M
}
/// Emits the key stored in the map entry `pair` as one more element of the result.
///
/// Steps, in order:
///  - the entry's mapped value is overwritten with -1 (presumably a "consumed" marker so the
///    same key is not emitted again -- confirm against the checks at the call sites);
///  - `result_offset` is incremented by one;
///  - the key is appended to `result_data` using the insertion method that matches the column kind;
///  - if `use_null_map` is set, a 0 is appended to `null_map` for the new element.
template <typename Mode>
template <typename Map, typename ColumnType, bool is_numeric_column>
void FunctionArrayIntersect<Mode>::insertElement(typename Map::LookupResult & pair, size_t & result_offset, ColumnType & result_data, NullMap & null_map, const bool & use_null_map)
{
    constexpr bool is_string_column
        = std::is_same_v<ColumnType, ColumnString> || std::is_same_v<ColumnType, ColumnFixedString>;

    pair->getMapped() = -1;
    result_offset += 1;

    const auto & key = pair->getKey();
    if constexpr (is_numeric_column)
        result_data.insertValue(key);
    else if constexpr (is_string_column)
        result_data.insertData(key.data, key.size);
    else
        std::ignore = result_data.deserializeAndInsertFromArena(key.data);

    if (!use_null_map)
        return;

    null_map.push_back(0);
}
/// Concrete instantiations of the FunctionArrayIntersect template, one per mode tag.
using ArrayIntersect = FunctionArrayIntersect<ArrayModeIntersect>;
using ArrayUnion = FunctionArrayIntersect<ArrayModeUnion>;
/// Registers both instantiations with the function factory from a single registration block.
REGISTER_FUNCTION(ArrayIntersect)
{
/// FunctionArrayIntersect is a template over Mode, so it is registered through the
/// aliases above rather than by its bare name, as the disabled line below did.
// factory.registerFunction<FunctionArrayIntersect>();
factory.registerFunction<ArrayIntersect>();
factory.registerFunction<ArrayUnion>();
}
}