diff --git a/src/Functions/array/arrayIntersect.cpp b/src/Functions/array/arrayIntersect.cpp index 209441eb301..8affe1ac11c 100644 --- a/src/Functions/array/arrayIntersect.cpp +++ b/src/Functions/array/arrayIntersect.cpp @@ -35,10 +35,21 @@ namespace ErrorCodes extern const int ILLEGAL_TYPE_OF_ARGUMENT; } +struct ArrayModeIntersect +{ + static constexpr auto name = "arrayIntersect"; +}; + +struct ArrayModeUnion +{ + static constexpr auto name = "arrayUnion"; +}; + +template class FunctionArrayIntersect : public IFunction { public: - static constexpr auto name = "arrayIntersect"; + static constexpr auto name = Mode::name; static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } explicit FunctionArrayIntersect(ContextPtr context_) : context(context_) {} @@ -124,8 +135,8 @@ private: }; }; - -DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & arguments) const +template +DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & arguments) const { DataTypes nested_types; nested_types.reserve(arguments.size()); @@ -162,7 +173,8 @@ DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & argument return std::make_shared(result_type); } -ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const +template +ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const { if (const auto * column_nullable = checkAndGetColumn(column.get())) { @@ -208,7 +220,8 @@ ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, c return column; } -FunctionArrayIntersect::CastArgumentsResult FunctionArrayIntersect::castColumns( +template +FunctionArrayIntersect::CastArgumentsResult FunctionArrayIntersect::castColumns( const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, const DataTypePtr & return_type_with_nulls) { size_t num_args = arguments.size(); @@ -294,7 +307,8 @@ static ColumnPtr callFunctionNotEquals(ColumnWithTypeAndName first, ColumnWithTy return eq_func->execute(args, eq_func->getResultType(), args.front().column->size()); } -FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays( +template +FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays( const ColumnsWithTypeAndName & columns, ColumnsWithTypeAndName & initial_columns) const { UnpackedArrays arrays; @@ -384,7 +398,8 @@ FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays( return arrays; } -ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const +template +ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const { const auto * return_type_array = checkAndGetDataType(result_type.get()); @@ -450,8 +465,9 @@ ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arg return result_column; } +template template -void FunctionArrayIntersect::NumberExecutor::operator()(TypeList) +void FunctionArrayIntersect::NumberExecutor::operator()(TypeList) { using Container = ClearableHashMapWithStackMemory, INITIAL_SIZE_DEGREE>; @@ -460,8 +476,9 @@ void FunctionArrayIntersect::NumberExecutor::operator()(TypeList) result = execute, true>(arrays, ColumnVector::create()); } +template template -void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList) +void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList) { using Container = ClearableHashMapWithStackMemory, INITIAL_SIZE_DEGREE>; @@ -471,8 +488,9 @@ void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList) result = execute, true>(arrays, ColumnDecimal::create(0, decimal->getScale())); } +template template -ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr) +ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr) { auto args = arrays.args.size(); auto rows = arrays.base_rows; @@ -641,9 +659,13 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable } +using ArrayIntersect = FunctionArrayIntersect; +using ArrayUnion = FunctionArrayIntersect; + REGISTER_FUNCTION(ArrayIntersect) { - factory.registerFunction(); + // factory.registerFunction(); + factory.registerFunction(); } }