Templatize FunctionArrayIntersect with new mode types

This commit is contained in:
Peter Nguyen 2024-08-21 17:54:38 -06:00
parent dc4281cc86
commit 1dcfaa91c2

View File

@ -35,10 +35,21 @@ namespace ErrorCodes
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
struct ArrayModeIntersect
{
static constexpr auto name = "arrayIntersect";
};
struct ArrayModeUnion
{
static constexpr auto name = "arrayUnion";
};
template <typename Mode>
class FunctionArrayIntersect : public IFunction
{
public:
static constexpr auto name = "arrayIntersect";
static constexpr auto name = Mode::name;
static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionArrayIntersect>(context); }
explicit FunctionArrayIntersect(ContextPtr context_) : context(context_) {}
@ -124,8 +135,8 @@ private:
};
};
DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & arguments) const
template <typename Mode>
DataTypePtr FunctionArrayIntersect<Mode>::getReturnTypeImpl(const DataTypes & arguments) const
{
DataTypes nested_types;
nested_types.reserve(arguments.size());
@ -162,7 +173,8 @@ DataTypePtr FunctionArrayIntersect::getReturnTypeImpl(const DataTypes & argument
return std::make_shared<DataTypeArray>(result_type);
}
ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const
template <typename Mode>
ColumnPtr FunctionArrayIntersect<Mode>::castRemoveNullable(const ColumnPtr & column, const DataTypePtr & data_type) const
{
if (const auto * column_nullable = checkAndGetColumn<ColumnNullable>(column.get()))
{
@ -208,7 +220,8 @@ ColumnPtr FunctionArrayIntersect::castRemoveNullable(const ColumnPtr & column, c
return column;
}
FunctionArrayIntersect::CastArgumentsResult FunctionArrayIntersect::castColumns(
template <typename Mode>
FunctionArrayIntersect<Mode>::CastArgumentsResult FunctionArrayIntersect<Mode>::castColumns(
const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type, const DataTypePtr & return_type_with_nulls)
{
size_t num_args = arguments.size();
@ -294,7 +307,8 @@ static ColumnPtr callFunctionNotEquals(ColumnWithTypeAndName first, ColumnWithTy
return eq_func->execute(args, eq_func->getResultType(), args.front().column->size());
}
FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays(
template <typename Mode>
FunctionArrayIntersect<Mode>::UnpackedArrays FunctionArrayIntersect<Mode>::prepareArrays(
const ColumnsWithTypeAndName & columns, ColumnsWithTypeAndName & initial_columns) const
{
UnpackedArrays arrays;
@ -384,7 +398,8 @@ FunctionArrayIntersect::UnpackedArrays FunctionArrayIntersect::prepareArrays(
return arrays;
}
ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
template <typename Mode>
ColumnPtr FunctionArrayIntersect<Mode>::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
const auto * return_type_array = checkAndGetDataType<DataTypeArray>(result_type.get());
@ -450,8 +465,9 @@ ColumnPtr FunctionArrayIntersect::executeImpl(const ColumnsWithTypeAndName & arg
return result_column;
}
template <typename Mode>
template <class T>
void FunctionArrayIntersect::NumberExecutor::operator()(TypeList<T>)
void FunctionArrayIntersect<Mode>::NumberExecutor::operator()(TypeList<T>)
{
using Container = ClearableHashMapWithStackMemory<T, size_t, DefaultHash<T>,
INITIAL_SIZE_DEGREE>;
@ -460,8 +476,9 @@ void FunctionArrayIntersect::NumberExecutor::operator()(TypeList<T>)
result = execute<Container, ColumnVector<T>, true>(arrays, ColumnVector<T>::create());
}
template <typename Mode>
template <class T>
void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList<T>)
void FunctionArrayIntersect<Mode>::DecimalExecutor::operator()(TypeList<T>)
{
using Container = ClearableHashMapWithStackMemory<T, size_t, DefaultHash<T>,
INITIAL_SIZE_DEGREE>;
@ -471,8 +488,9 @@ void FunctionArrayIntersect::DecimalExecutor::operator()(TypeList<T>)
result = execute<Container, ColumnDecimal<T>, true>(arrays, ColumnDecimal<T>::create(0, decimal->getScale()));
}
template <typename Mode>
template <typename Map, typename ColumnType, bool is_numeric_column>
ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr)
ColumnPtr FunctionArrayIntersect<Mode>::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr)
{
auto args = arrays.args.size();
auto rows = arrays.base_rows;
@ -641,9 +659,13 @@ ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, Mutable
}
using ArrayIntersect = FunctionArrayIntersect<ArrayModeIntersect>;
using ArrayUnion = FunctionArrayIntersect<ArrayModeUnion>;
REGISTER_FUNCTION(ArrayIntersect)
{
factory.registerFunction<FunctionArrayIntersect>();
// factory.registerFunction<FunctionArrayIntersect>();
factory.registerFunction<ArrayIntersect>();
}
}