#pragma once #include #include #include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } /** Higher-order functions for arrays. * These functions optionally apply a map (transform) to array (or multiple arrays of identical size) by lambda function, * and return some result based on that transformation. * * Examples: * arrayMap(x1,...,xn -> expression, array1,...,arrayn) - apply the expression to each element of the array (or set of parallel arrays). * arrayFilter(x -> predicate, array) - leave in the array only the elements for which the expression is true. * * For some functions arrayCount, arrayExists, arrayAll, an overload of the form f(array) is available, which works in the same way as f(x -> x, array). * * See the example of Impl template parameter in arrayMap.cpp */ template class FunctionArrayMapped : public IFunction { public: static constexpr auto name = Name::name; static FunctionPtr create(const Context &) { return std::make_shared(); } String getName() const override { return name; } bool isVariadic() const override { return true; } size_t getNumberOfArguments() const override { return 0; } /// Called if at least one function argument is a lambda expression. /// For argument-lambda expressions, it defines the types of arguments of these expressions. void getLambdaArgumentTypes(DataTypes & arguments) const override { if (arguments.size() < 1) throw Exception("Function " + getName() + " needs at least one argument; passed " + toString(arguments.size()) + ".", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (arguments.size() == 1) throw Exception("Function " + getName() + " needs at least one array argument.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); DataTypes nested_types(arguments.size() - 1); for (size_t i = 0; i < nested_types.size(); ++i) { const DataTypeArray * array_type = checkAndGetDataType(&*arguments[i + 1]); if (!array_type) throw Exception("Argument " + toString(i + 2) + " of function " + getName() + " must be array. Found " + arguments[i + 1]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); nested_types[i] = removeLowCardinality(array_type->getNestedType()); } const DataTypeFunction * function_type = checkAndGetDataType(arguments[0].get()); if (!function_type || function_type->getArgumentTypes().size() != nested_types.size()) throw Exception("First argument for this overload of " + getName() + " must be a function with " + toString(nested_types.size()) + " arguments. Found " + arguments[0]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); arguments[0] = std::make_shared(nested_types); } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { size_t min_args = Impl::needExpression() ? 2 : 1; if (arguments.size() < min_args) throw Exception("Function " + getName() + " needs at least " + toString(min_args) + " argument; passed " + toString(arguments.size()) + ".", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (arguments.size() == 1) { const auto array_type = checkAndGetDataType(arguments[0].type.get()); if (!array_type) throw Exception("The only argument for function " + getName() + " must be array. Found " + arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); DataTypePtr nested_type = array_type->getNestedType(); if (Impl::needBoolean() && !WhichDataType(nested_type).isUInt8()) throw Exception("The only argument for function " + getName() + " must be array of UInt8. Found " + arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); return Impl::getReturnType(nested_type, nested_type); } else { if (arguments.size() > 2 && Impl::needOneArray()) throw Exception("Function " + getName() + " needs one array argument.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); const auto data_type_function = checkAndGetDataType(arguments[0].type.get()); if (!data_type_function) throw Exception("First argument for function " + getName() + " must be a function.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); /// The types of the remaining arguments are already checked in getLambdaArgumentTypes. DataTypePtr return_type = removeLowCardinality(data_type_function->getReturnType()); if (Impl::needBoolean() && !WhichDataType(return_type).isUInt8()) throw Exception("Expression for function " + getName() + " must return UInt8, found " + return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); const auto first_array_type = checkAndGetDataType(arguments[1].type.get()); return Impl::getReturnType(return_type, first_array_type->getNestedType()); } } void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override { if (arguments.size() == 1) { ColumnPtr column_array_ptr = block.getByPosition(arguments[0]).column; const auto * column_array = checkAndGetColumn(column_array_ptr.get()); if (!column_array) { const ColumnConst * column_const_array = checkAndGetColumnConst(column_array_ptr.get()); if (!column_const_array) throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN); column_array_ptr = column_const_array->convertToFullColumn(); column_array = assert_cast(column_array_ptr.get()); } block.getByPosition(result).column = Impl::execute(*column_array, column_array->getDataPtr()); } else { const auto & column_with_type_and_name = block.getByPosition(arguments[0]); if (!column_with_type_and_name.column) throw Exception("First argument for function " + getName() + " must be a function.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); const auto * column_function = typeid_cast(column_with_type_and_name.column.get()); if (!column_function) throw Exception("First argument for function " + getName() + " must be a function.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ColumnPtr offsets_column; ColumnPtr column_first_array_ptr; const ColumnArray * column_first_array = nullptr; ColumnsWithTypeAndName arrays; arrays.reserve(arguments.size() - 1); for (size_t i = 1; i < arguments.size(); ++i) { const auto & array_with_type_and_name = block.getByPosition(arguments[i]); ColumnPtr column_array_ptr = array_with_type_and_name.column; const auto * column_array = checkAndGetColumn(column_array_ptr.get()); const DataTypePtr & array_type_ptr = array_with_type_and_name.type; const auto * array_type = checkAndGetDataType(array_type_ptr.get()); if (!column_array) { const ColumnConst * column_const_array = checkAndGetColumnConst(column_array_ptr.get()); if (!column_const_array) throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN); column_array_ptr = column_const_array->convertToFullColumn(); if (column_array_ptr->lowCardinality()) column_array_ptr = column_array_ptr->convertToFullColumnIfLowCardinality(); column_array = checkAndGetColumn(column_array_ptr.get()); } if (!array_type) throw Exception("Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if (!offsets_column) { offsets_column = column_array->getOffsetsPtr(); } else { /// The first condition is optimization: do not compare data if the pointers are equal. if (column_array->getOffsetsPtr() != offsets_column && column_array->getOffsets() != typeid_cast(*offsets_column).getData()) throw Exception("Arrays passed to " + getName() + " must have equal size", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); } if (i == 1) { column_first_array_ptr = column_array_ptr; column_first_array = column_array; } arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(), removeLowCardinality(array_type->getNestedType()), array_with_type_and_name.name)); } /// Put all the necessary columns multiplied by the sizes of arrays into the block. auto replicated_column_function_ptr = (*column_function->replicate(column_first_array->getOffsets())).mutate(); auto * replicated_column_function = typeid_cast(replicated_column_function_ptr.get()); replicated_column_function->appendArguments(arrays); auto lambda_result = replicated_column_function->reduce().column; if (lambda_result->lowCardinality()) lambda_result = lambda_result->convertToFullColumnIfLowCardinality(); block.getByPosition(result).column = Impl::execute(*column_first_array, lambda_result); } } }; }