mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-21 01:00:48 +00:00
Refactor FunctionArrayMapped to allow for a fixed number of extra positional arguments
This commit is contained in:
parent
300dfbbef6
commit
23d173a53c
@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <span>
|
||||
#include <type_traits>
|
||||
|
||||
#include <Columns/ColumnArray.h>
|
||||
@ -75,6 +76,10 @@ const IColumn::Offsets & getOffsets(const T & column)
|
||||
* arrayMap(x1,...,xn -> expression, array1,...,arrayn) - apply the expression to each element of the array (or set of parallel arrays).
|
||||
* arrayFilter(x -> predicate, array) - leave in the array only the elements for which the expression is true.
|
||||
*
|
||||
* It is possible for the functions to require fixed number of positional arguments:
|
||||
* arrayPartialSort(limit, arr)
|
||||
* arrayPartialSort(x -> predicate, limit, arr)
|
||||
*
|
||||
* For some functions arrayCount, arrayExists, arrayAll, an overload of the form f(array) is available,
|
||||
* which works in the same way as f(x -> x, array).
|
||||
*
|
||||
@ -88,12 +93,13 @@ public:
|
||||
static constexpr bool is_argument_type_map = std::is_same_v<typename Impl::data_type, DataTypeMap>;
|
||||
static constexpr bool is_argument_type_array = std::is_same_v<typename Impl::data_type, DataTypeArray>;
|
||||
static constexpr auto argument_type_name = is_argument_type_map ? "Map" : "Array";
|
||||
|
||||
static constexpr bool has_num_fixed_params = requires(const Impl &) { Impl::num_fixed_params; };
|
||||
static constexpr size_t num_fixed_params = []{ if constexpr (has_num_fixed_params) return Impl::num_fixed_params; else return 0; }();
|
||||
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayMapped>(); }
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return name;
|
||||
}
|
||||
String getName() const override { return name; }
|
||||
|
||||
bool isVariadic() const override { return true; }
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
@ -104,30 +110,41 @@ public:
|
||||
void getLambdaArgumentTypes(DataTypes & arguments) const override
|
||||
{
|
||||
if (arguments.empty())
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Function {} needs at least one argument, passed {}", getName(), arguments.size());
|
||||
throw Exception(
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Function {} needs at least one argument, passed {}",
|
||||
getName(),
|
||||
arguments.size());
|
||||
|
||||
if (arguments.size() == 1)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Function {} needs at least one argument with data", getName());
|
||||
if (arguments.size() < 1 + num_fixed_params)
|
||||
throw Exception(
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Function {} needs at least {} argument{} with data",
|
||||
getName(),
|
||||
num_fixed_params + 1,
|
||||
(num_fixed_params + 1 == 1) ? "" : "s");
|
||||
|
||||
if (arguments.size() > 2 && Impl::needOneArray())
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Function {} needs one argument with data", getName());
|
||||
if (arguments.size() > 2 + num_fixed_params && Impl::needOneArray())
|
||||
throw Exception(
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Function {} needs {} argument{} with data",
|
||||
getName(),
|
||||
num_fixed_params + 1,
|
||||
(num_fixed_params + 1 == 1) ? "" : "s");
|
||||
|
||||
size_t nested_types_count = is_argument_type_map ? (arguments.size() - 1) * 2 : (arguments.size() - 1);
|
||||
size_t nested_types_count = (arguments.size() - num_fixed_params - 1) * (is_argument_type_map ? 2 : 1);
|
||||
DataTypes nested_types(nested_types_count);
|
||||
for (size_t i = 0; i < arguments.size() - 1; ++i)
|
||||
for (size_t i = 0; i < arguments.size() - 1 - num_fixed_params; ++i)
|
||||
{
|
||||
const auto * array_type = checkAndGetDataType<typename Impl::data_type>(&*arguments[i + 1]);
|
||||
const auto * array_type = checkAndGetDataType<typename Impl::data_type>(&*arguments[i + 1 + num_fixed_params]);
|
||||
if (!array_type)
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Argument {} of function {} must be {}. Found {} instead",
|
||||
toString(i + 2),
|
||||
i + 2 + num_fixed_params,
|
||||
getName(),
|
||||
argument_type_name,
|
||||
arguments[i + 1]->getName());
|
||||
arguments[i + 1 + num_fixed_params]->getName());
|
||||
if constexpr (is_argument_type_map)
|
||||
{
|
||||
nested_types[2 * i] = recursiveRemoveLowCardinality(array_type->getKeyType());
|
||||
@ -144,32 +161,54 @@ public:
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"First argument for this overload of {} must be a function with {} arguments, found {} instead",
|
||||
getName(), nested_types.size(), arguments[0]->getName());
|
||||
getName(),
|
||||
nested_types.size(),
|
||||
arguments[0]->getName());
|
||||
|
||||
arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
size_t min_args = Impl::needExpression() ? 2 : 1;
|
||||
size_t min_args = (num_fixed_params + Impl::needExpression()) ? 2 : 1;
|
||||
if (arguments.size() < min_args)
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Function {} needs at least {} argument, passed {}",
|
||||
getName(), min_args, arguments.size());
|
||||
throw Exception(
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
|
||||
"Function {} needs at least {} argument{}, passed {}",
|
||||
getName(),
|
||||
min_args,
|
||||
(min_args > 1 ? "s" : ""),
|
||||
arguments.size());
|
||||
|
||||
if ((arguments.size() == 1) && is_argument_type_array)
|
||||
if ((arguments.size() == 1 + num_fixed_params) && is_argument_type_array)
|
||||
{
|
||||
const auto * data_type = checkAndGetDataType<typename Impl::data_type>(arguments[0].type.get());
|
||||
const auto * data_type = checkAndGetDataType<typename Impl::data_type>(arguments[num_fixed_params].type.get());
|
||||
|
||||
if (!data_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The only argument for function {} must be array. "
|
||||
"Found {} instead", getName(), arguments[0].type->getName());
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"The {}{}{} argument for function {} must be array. Found {} instead",
|
||||
num_fixed_params + 1,
|
||||
getOrdinalSuffix(num_fixed_params + 1),
|
||||
(num_fixed_params == 0 ? " and only" : ""),
|
||||
getName(),
|
||||
arguments[num_fixed_params].type->getName());
|
||||
|
||||
if constexpr (num_fixed_params)
|
||||
Impl::checkArguments(
|
||||
std::span<const ColumnWithTypeAndName, num_fixed_params>(std::begin(arguments), num_fixed_params), getName());
|
||||
|
||||
DataTypePtr nested_type = data_type->getNestedType();
|
||||
|
||||
if (Impl::needBoolean() && !isUInt8(nested_type))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The only argument for function {} must be array of UInt8. "
|
||||
"Found {} instead", getName(), arguments[0].type->getName());
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"The {}{}{} argument for function {} must be array of UInt8. Found {} instead",
|
||||
num_fixed_params + 1,
|
||||
getOrdinalSuffix(num_fixed_params + 1),
|
||||
(num_fixed_params == 0 ? " and only" : ""),
|
||||
getName(),
|
||||
arguments[num_fixed_params].type->getName());
|
||||
|
||||
if constexpr (is_argument_type_array)
|
||||
return Impl::getReturnType(nested_type, nested_type);
|
||||
@ -178,17 +217,22 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
if (arguments.size() > 2 && Impl::needOneArray())
|
||||
if (arguments.size() > 2 + num_fixed_params && Impl::needOneArray())
|
||||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs one argument with data", getName());
|
||||
|
||||
const auto * data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
|
||||
|
||||
if (!data_type_function)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"First argument for function {} must be a function. Actual {}",
|
||||
getName(),
|
||||
arguments[0].type->getName());
|
||||
|
||||
if constexpr (num_fixed_params)
|
||||
Impl::checkArguments(
|
||||
std::span<const ColumnWithTypeAndName, num_fixed_params>(std::begin(arguments) + 1, num_fixed_params), getName());
|
||||
|
||||
/// The types of the remaining arguments are already checked in getLambdaArgumentTypes.
|
||||
|
||||
DataTypePtr return_type = removeLowCardinality(data_type_function->getReturnType());
|
||||
@ -199,21 +243,24 @@ public:
|
||||
/// - lambda may return Nothing or Nullable(Nothing) because of default implementation of functions
|
||||
/// for these types. In this case we will just create UInt8 const column full of 0.
|
||||
if (Impl::needBoolean() && !isUInt8(removeNullable(return_type)) && !isNothing(removeNullable(return_type)))
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Expression for function {} must return UInt8 or Nullable(UInt8), found {}",
|
||||
getName(), return_type->getName());
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Expression for function {} must return UInt8 or Nullable(UInt8), found {}",
|
||||
getName(),
|
||||
return_type->getName());
|
||||
|
||||
static_assert(is_argument_type_map || is_argument_type_array, "unsupported type");
|
||||
|
||||
if (arguments.size() < 2)
|
||||
if (arguments.size() < 2 + num_fixed_params)
|
||||
{
|
||||
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Incorrect number of arguments: {}", arguments.size());
|
||||
}
|
||||
|
||||
const auto * first_array_type = checkAndGetDataType<typename Impl::data_type>(arguments[1].type.get());
|
||||
const auto * first_array_type = checkAndGetDataType<typename Impl::data_type>(arguments[1 + num_fixed_params].type.get());
|
||||
|
||||
if (!first_array_type)
|
||||
throw DB::Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Unsupported type {}", arguments[1].type->getName());
|
||||
throw DB::Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Unsupported type {}", arguments[1 + num_fixed_params].type->getName());
|
||||
|
||||
if constexpr (is_argument_type_array)
|
||||
return Impl::getReturnType(return_type, first_array_type->getNestedType());
|
||||
@ -227,9 +274,9 @@ public:
|
||||
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
|
||||
{
|
||||
if (arguments.size() == 1)
|
||||
if (arguments.size() == 1 + num_fixed_params)
|
||||
{
|
||||
ColumnPtr column_array_ptr = arguments[0].column;
|
||||
ColumnPtr column_array_ptr = arguments[num_fixed_params].column;
|
||||
const auto * column_array = checkAndGetColumn<typename Impl::column_type>(column_array_ptr.get());
|
||||
|
||||
if (!column_array)
|
||||
@ -237,21 +284,30 @@ public:
|
||||
const ColumnConst * column_const_array = checkAndGetColumnConst<typename Impl::column_type>(column_array_ptr.get());
|
||||
if (!column_const_array)
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_COLUMN,
|
||||
"Expected {} column, found {}",
|
||||
argument_type_name,
|
||||
column_array_ptr->getName());
|
||||
ErrorCodes::ILLEGAL_COLUMN, "Expected {} column, found {}", argument_type_name, column_array_ptr->getName());
|
||||
column_array_ptr = column_const_array->convertToFullColumn();
|
||||
column_array = assert_cast<const typename Impl::column_type *>(column_array_ptr.get());
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<typename Impl::column_type, ColumnMap>)
|
||||
{
|
||||
return Impl::execute(*column_array, column_array->getNestedColumn().getDataPtr());
|
||||
if constexpr (num_fixed_params)
|
||||
return Impl::execute(
|
||||
*column_array,
|
||||
column_array->getNestedColumn().getDataPtr(),
|
||||
std::span<const ColumnWithTypeAndName, num_fixed_params>(std::begin(arguments), num_fixed_params));
|
||||
else
|
||||
return Impl::execute(*column_array, column_array->getNestedColumn().getDataPtr());
|
||||
}
|
||||
else
|
||||
{
|
||||
return Impl::execute(*column_array, column_array->getDataPtr());
|
||||
if constexpr (num_fixed_params)
|
||||
return Impl::execute(
|
||||
*column_array,
|
||||
column_array->getDataPtr(),
|
||||
std::span<const ColumnWithTypeAndName, num_fixed_params>(std::begin(arguments), num_fixed_params));
|
||||
else
|
||||
return Impl::execute(*column_array, column_array->getDataPtr());
|
||||
}
|
||||
}
|
||||
else
|
||||
@ -274,7 +330,7 @@ public:
|
||||
ColumnsWithTypeAndName arrays;
|
||||
arrays.reserve(arguments.size() - 1);
|
||||
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
for (size_t i = 1 + num_fixed_params; i < arguments.size(); ++i)
|
||||
{
|
||||
const auto & array_with_type_and_name = arguments[i];
|
||||
|
||||
@ -314,7 +370,7 @@ public:
|
||||
getName());
|
||||
}
|
||||
|
||||
if (i == 1)
|
||||
if (i == 1 + num_fixed_params)
|
||||
{
|
||||
column_first_array_ptr = column_array_ptr;
|
||||
column_first_array = column_array;
|
||||
@ -380,7 +436,13 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
return Impl::execute(*column_first_array, lambda_result.column);
|
||||
if constexpr (num_fixed_params)
|
||||
return Impl::execute(
|
||||
*column_first_array,
|
||||
lambda_result.column,
|
||||
std::span<const ColumnWithTypeAndName, num_fixed_params>(std::begin(arguments) + 1, num_fixed_params));
|
||||
else
|
||||
return Impl::execute(*column_first_array, lambda_result.column);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user