#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int LOGICAL_ERROR; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_TYPE_OF_ARGUMENT; } class FunctionArrayIntersect : public IFunction { public: static constexpr auto name = "arrayIntersect"; static FunctionPtr create(const Context & context) { return std::make_shared(context); } FunctionArrayIntersect(const Context & context_) : context(context_) {} String getName() const override { return name; } bool isVariadic() const override { return true; } size_t getNumberOfArguments() const override { return 0; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override; void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override; bool useDefaultImplementationForConstants() const override { return true; } private: const Context & context; /// Initially allocate a piece of memory for 512 elements. NOTE: This is just a guess. static constexpr size_t INITIAL_SIZE_DEGREE = 9; struct UnpackedArrays { size_t base_rows = 0; struct UnpackedArray { bool is_const = false; const NullMap * null_map = nullptr; const NullMap * overflow_mask = nullptr; const ColumnArray::ColumnOffsets::Container * offsets = nullptr; const IColumn * nested_column = nullptr; }; std::vector args; Columns column_holders; UnpackedArrays() = default; }; /// Cast column to data_type removing nullable if data_type hasn't. /// It's expected that column can represent data_type after removing some NullMap's. ColumnPtr castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const; struct CastArgumentsResult { ColumnsWithTypeAndName initial; ColumnsWithTypeAndName casted; }; CastArgumentsResult castColumns(Block & block, const ColumnNumbers & arguments, const DataTypePtr & return_type, const DataTypePtr & return_type_with_nulls) const; UnpackedArrays prepareArrays(const ColumnsWithTypeAndName & columns, ColumnsWithTypeAndName & initial_columns) const; template static ColumnPtr execute(const UnpackedArrays & arrays, MutableColumnPtr result_data); struct NumberExecutor { const UnpackedArrays & arrays; const DataTypePtr & data_type; ColumnPtr & result; NumberExecutor(const UnpackedArrays & arrays_, const DataTypePtr & data_type_, ColumnPtr & result_) : arrays(arrays_), data_type(data_type_), result(result_) {} template void operator()(); }; struct DecimalExecutor { const UnpackedArrays & arrays; const DataTypePtr & data_type; ColumnPtr & result; DecimalExecutor(const UnpackedArrays & arrays_, const DataTypePtr & data_type_, ColumnPtr & result_) : arrays(arrays_), data_type(data_type_), result(result_) {} template void operator()(); }; }; DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & arguments) const { DataTypes nested_types; nested_types.reserve(arguments.size()); bool has_nothing = false; if (arguments.empty()) throw Exception{"Function " + getName() + " requires at least one argument.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; for (auto i : ext::range(0, arguments.size())) { auto array_type = typeid_cast(arguments[i].get()); if (!array_type) throw Exception("Argument " + std::to_string(i) + " for function " + getName() + " must be an array but it has type " + arguments[i]->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); const auto & nested_type = array_type->getNestedType(); if (typeid_cast(nested_type.get())) has_nothing = true; else nested_types.push_back(nested_type); } DataTypePtr result_type; if (!nested_types.empty()) result_type = getMostSubtype(nested_types, true); if (has_nothing) result_type = std::make_shared(); return std::make_shared(result_type); } ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const { if (auto column_nullable = checkAndGetColumn(column.get())) { auto nullable_type = checkAndGetDataType(data_type.get()); const auto & nested = column_nullable->getNestedColumnPtr(); if (nullable_type) { auto casted_column = castRemoveNullable(nested, nullable_type->getNestedType()); return ColumnNullable::create(casted_column, column_nullable->getNullMapColumnPtr()); } return castRemoveNullable(nested, data_type); } else if (auto column_array = checkAndGetColumn(column.get())) { auto array_type = checkAndGetDataType(data_type.get()); if (!array_type) throw Exception{"Cannot cast array column to column with type " + data_type->getName() + " in function " + getName(), ErrorCodes::LOGICAL_ERROR}; auto casted_column = castRemoveNullable(column_array->getDataPtr(), array_type->getNestedType()); return ColumnArray::create(casted_column, column_array->getOffsetsPtr()); } else if (auto column_tuple = checkAndGetColumn(column.get())) { auto tuple_type = checkAndGetDataType(data_type.get()); if (!tuple_type) throw Exception{"Cannot cast tuple column to type " + data_type->getName() + " in function " + getName(), ErrorCodes::LOGICAL_ERROR}; auto columns_number = column_tuple->tupleSize(); Columns columns(columns_number); const auto & types = tuple_type->getElements(); for (auto i : ext::range(0, columns_number)) { columns[i] = castRemoveNullable(column_tuple->getColumnPtr(i), types[i]); } return ColumnTuple::create(columns); } return column; } FunctionArrayIntersect::CastArgumentsResult FunctionArrayIntersect::castColumns( Block & block, const ColumnNumbers & arguments, const DataTypePtr & return_type, const DataTypePtr & return_type_with_nulls) const { size_t num_args = arguments.size(); ColumnsWithTypeAndName initial_columns(num_args); ColumnsWithTypeAndName columns(num_args); auto type_array = checkAndGetDataType(return_type.get()); auto & type_nested = type_array->getNestedType(); auto type_not_nullable_nested = removeNullable(type_nested); const bool is_numeric_or_string = isNativeNumber(type_not_nullable_nested) || isDateOrDateTime(type_not_nullable_nested) || isStringOrFixedString(type_not_nullable_nested); DataTypePtr nullable_return_type; if (is_numeric_or_string) { auto type_nullable_nested = makeNullable(type_nested); nullable_return_type = std::make_shared(type_nullable_nested); } const bool nested_is_nullable = type_nested->isNullable(); for (size_t i = 0; i < num_args; ++i) { const ColumnWithTypeAndName & arg = block.getByPosition(arguments[i]); initial_columns[i] = arg; columns[i] = arg; auto & column = columns[i]; if (is_numeric_or_string) { /// Cast to Array(T) or Array(Nullable(T)). if (nested_is_nullable) { if (!arg.type->equals(*return_type)) { column.column = castColumn(arg, return_type, context); column.type = return_type; } } else { if (!arg.type->equals(*return_type) && !arg.type->equals(*nullable_return_type)) { /// If result has array type Array(T) still cast Array(Nullable(U)) to Array(Nullable(T)) /// because cannot cast Nullable(T) to T. if (static_cast(*arg.type).getNestedType()->isNullable()) { column.column = castColumn(arg, nullable_return_type, context); column.type = nullable_return_type; } else { column.column = castColumn(arg, return_type, context); column.type = return_type; } } } } else { /// return_type_with_nulls is the most common subtype with possible nullable parts. if (!arg.type->equals(*return_type_with_nulls)) { column.column = castColumn(arg, return_type_with_nulls, context); column.type = return_type_with_nulls; } } } return {.initial = initial_columns, .casted = columns}; } static ColumnPtr callFunctionNotEquals(ColumnWithTypeAndName first, ColumnWithTypeAndName second, const Context & context) { ColumnsWithTypeAndName args; args.reserve(2); args.emplace_back(std::move(first)); args.emplace_back(std::move(second)); auto eq_func = FunctionFactory::instance().get("notEquals", context)->build(args); Block block = args; block.insert({nullptr, eq_func->getReturnType(), ""}); eq_func->execute(block, {0, 1}, 2, args.front().column->size()); return block.getByPosition(2).column; } FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays( const ColumnsWithTypeAndName & columns, ColumnsWithTypeAndName & initial_columns) const { UnpackedArrays arrays; size_t columns_number = columns.size(); arrays.args.resize(columns_number); bool all_const = true; for (auto i : ext::range(0, columns_number)) { auto & arg = arrays.args[i]; auto argument_column = columns[i].column.get(); auto initial_column = initial_columns[i].column.get(); if (auto argument_column_const = typeid_cast(argument_column)) { arg.is_const = true; argument_column = argument_column_const->getDataColumnPtr().get(); initial_column = typeid_cast(initial_column)->getDataColumnPtr().get(); } if (auto argument_column_array = typeid_cast(argument_column)) { if (!arg.is_const) all_const = false; arg.offsets = &argument_column_array->getOffsets(); arg.nested_column = &argument_column_array->getData(); initial_column = &typeid_cast(initial_column)->getData(); if (auto column_nullable = typeid_cast(arg.nested_column)) { arg.null_map = &column_nullable->getNullMapData(); arg.nested_column = &column_nullable->getNestedColumn(); initial_column = &typeid_cast(initial_column)->getNestedColumn(); } /// In case column was casted need to create overflow mask for integer types. if (arg.nested_column != initial_column) { auto & nested_init_type = typeid_cast(removeNullable(initial_columns[i].type).get())->getNestedType(); auto & nested_cast_type = typeid_cast(removeNullable(columns[i].type).get())->getNestedType(); if (isInteger(nested_init_type) || isDateOrDateTime(nested_init_type)) { /// Compare original and casted columns. It seem to be the easiest way. auto overflow_mask = callFunctionNotEquals( {arg.nested_column->getPtr(), nested_init_type, ""}, {initial_column->getPtr(), nested_cast_type, ""}, context); arg.overflow_mask = &typeid_cast(overflow_mask.get())->getData(); arrays.column_holders.emplace_back(std::move(overflow_mask)); } } } else throw Exception{"Arguments for function " + getName() + " must be arrays.", ErrorCodes::LOGICAL_ERROR}; } if (all_const) { arrays.base_rows = arrays.args.front().offsets->size(); } else { for (auto i : ext::range(0, columns_number)) { if (arrays.args[i].is_const) continue; size_t rows = arrays.args[i].offsets->size(); if (arrays.base_rows == 0 && rows > 0) arrays.base_rows = rows; else if (arrays.base_rows != rows) throw Exception("Non-const array columns in function " + getName() + "should have same rows", ErrorCodes::LOGICAL_ERROR); } } return arrays; } void FunctionArrayIntersect::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) { const auto & return_type = block.getByPosition(result).type; auto return_type_array = checkAndGetDataType(return_type.get()); if (!return_type_array) throw Exception{"Return type for function " + getName() + " must be array.", ErrorCodes::LOGICAL_ERROR}; const auto & nested_return_type = return_type_array->getNestedType(); if (typeid_cast(nested_return_type.get())) { block.getByPosition(result).column = return_type->createColumnConstWithDefaultValue(input_rows_count); return; } auto num_args = arguments.size(); DataTypes data_types; data_types.reserve(num_args); for (size_t i = 0; i < num_args; ++i) data_types.push_back(block.getByPosition(arguments[i]).type); auto return_type_with_nulls = getMostSubtype(data_types, true, true); auto columns = castColumns(block, arguments, return_type, return_type_with_nulls); UnpackedArrays arrays = prepareArrays(columns.casted, columns.initial); ColumnPtr result_column; auto not_nullable_nested_return_type = removeNullable(nested_return_type); TypeListNativeNumbers::forEach(NumberExecutor(arrays, not_nullable_nested_return_type, result_column)); TypeListDecimalNumbers::forEach(DecimalExecutor(arrays, not_nullable_nested_return_type, result_column)); using DateMap = ClearableHashMap, HashTableGrower, HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(DataTypeDate::FieldType)>>; using DateTimeMap = ClearableHashMap, HashTableGrower, HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(DataTypeDateTime::FieldType)>>; using StringMap = ClearableHashMap, HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(StringRef)>>; if (!result_column) { auto column = not_nullable_nested_return_type->createColumn(); WhichDataType which(not_nullable_nested_return_type); if (which.isDate()) result_column = execute, true>(arrays, std::move(column)); else if (which.isDateTime()) result_column = execute, true>(arrays, std::move(column)); else if (which.isString()) result_column = execute(arrays, std::move(column)); else if (which.isFixedString()) result_column = execute(arrays, std::move(column)); else { column = assert_cast(*return_type_with_nulls).getNestedType()->createColumn(); result_column = castRemoveNullable(execute(arrays, std::move(column)), return_type); } } block.getByPosition(result).column = std::move(result_column); } template void FunctionArrayIntersect::NumberExecutor::operator()() { using Map = ClearableHashMap, HashTableGrower, HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(T)>>; if (!result && typeid_cast *>(data_type.get())) result = execute, true>(arrays, ColumnVector::create()); } template void FunctionArrayIntersect::DecimalExecutor::operator()() { using Map = ClearableHashMap, HashTableGrower, HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(T)>>; if (!result) if (auto * decimal = typeid_cast *>(data_type.get())) result = execute, true>(arrays, ColumnDecimal::create(0, decimal->getScale())); } template ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr) { auto args = arrays.args.size(); auto rows = arrays.base_rows; bool all_nullable = true; std::vector columns; columns.reserve(args); for (auto & arg : arrays.args) { if constexpr (std::is_same::value) columns.push_back(arg.nested_column); else columns.push_back(checkAndGetColumn(arg.nested_column)); if (!columns.back()) throw Exception("Unexpected array type for function arrayIntersect", ErrorCodes::LOGICAL_ERROR); if (!arg.null_map) all_nullable = false; } auto & result_data = static_cast(*result_data_ptr); auto result_offsets_ptr = ColumnArray::ColumnOffsets::create(rows); auto & result_offsets = assert_cast(*result_offsets_ptr); auto null_map_column = ColumnUInt8::create(); NullMap & null_map = assert_cast(*null_map_column).getData(); Arena arena; Map map; std::vector prev_off(args, 0); size_t result_offset = 0; for (auto row : ext::range(0, rows)) { map.clear(); bool all_has_nullable = all_nullable; for (auto arg_num : ext::range(0, args)) { auto & arg = arrays.args[arg_num]; bool current_has_nullable = false; size_t off; // const array has only one row if (arg.is_const) off = (*arg.offsets)[0]; else off = (*arg.offsets)[row]; for (auto i : ext::range(prev_off[arg_num], off)) { if (arg.null_map && (*arg.null_map)[i]) current_has_nullable = true; else if (!arg.overflow_mask || (*arg.overflow_mask)[i] == 0) { typename Map::mapped_type * value = nullptr; if constexpr (is_numeric_column) value = &map[columns[arg_num]->getElement(i)]; else if constexpr (std::is_same::value || std::is_same::value) value = &map[columns[arg_num]->getDataAt(i)]; else { const char * data = nullptr; value = &map[columns[arg_num]->serializeValueIntoArena(i, arena, data)]; } /// Here we count the number of element appearances, but no more than once per array. if (*value == arg_num) ++(*value); } } prev_off[arg_num] = off; if (arg.is_const) prev_off[arg_num] = 0; if (!current_has_nullable) all_has_nullable = false; } if (all_has_nullable) { ++result_offset; result_data.insertDefault(); null_map.push_back(1); } for (const auto & pair : map) { if (pair.getMapped() == args) { ++result_offset; if constexpr (is_numeric_column) result_data.insertValue(pair.getKey()); else if constexpr (std::is_same::value || std::is_same::value) result_data.insertData(pair.getKey().data, pair.getKey().size); else result_data.deserializeAndInsertFromArena(pair.getKey().data); if (all_nullable) null_map.push_back(0); } } result_offsets.getElement(row) = result_offset; } ColumnPtr result_column = std::move(result_data_ptr); if (all_nullable) result_column = ColumnNullable::create(result_column, std::move(null_map_column)); return ColumnArray::create(result_column, std::move(result_offsets_ptr)); } void registerFunctionArrayIntersect(FunctionFactory & factory) { factory.registerFunction(); } }