#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int LOGICAL_ERROR; extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } template ColumnPtr getOffsetsPtr(const T & column) { if constexpr (std::is_same_v) { return column.getOffsetsPtr(); } else // ColumnMap { return column.getNestedColumn().getOffsetsPtr(); } } template const IColumn::Offsets & getOffsets(const T & column) { if constexpr (std::is_same_v) { return column.getOffsets(); } else // ColumnMap { return column.getNestedColumn().getOffsets(); } } /** Higher-order functions for arrays. * These functions optionally apply a map (transform) to array (or multiple arrays of identical size) by lambda function, * and return some result based on that transformation. * * Examples: * arrayMap(x1,...,xn -> expression, array1,...,arrayn) - apply the expression to each element of the array (or set of parallel arrays). * arrayFilter(x -> predicate, array) - leave in the array only the elements for which the expression is true. * * For some functions arrayCount, arrayExists, arrayAll, an overload of the form f(array) is available, * which works in the same way as f(x -> x, array). * * See the example of Impl template parameter in arrayMap.cpp */ template class FunctionArrayMapped : public IFunction { public: static constexpr auto name = Name::name; static FunctionPtr create(ContextPtr) { return std::make_shared(); } String getName() const override { return name; } bool isVariadic() const override { return true; } size_t getNumberOfArguments() const override { return 0; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } /// Called if at least one function argument is a lambda expression. /// For argument-lambda expressions, it defines the types of arguments of these expressions. void getLambdaArgumentTypes(DataTypes & arguments) const override { if (arguments.empty()) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs at least one argument, passed {}", getName(), arguments.size()); if (arguments.size() == 1) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs at least one argument with data", getName()); if (arguments.size() > 2 && Impl::needOneArray()) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs one argument with data", getName()); size_t nested_types_count = std::is_same_v ? (arguments.size() - 1) * 2 : (arguments.size() - 1); DataTypes nested_types(nested_types_count); for (size_t i = 0; i < arguments.size() - 1; ++i) { const auto * array_type = checkAndGetDataType(&*arguments[i + 1]); if (!array_type) throw Exception("Argument " + toString(i + 2) + " of function " + getName() + " must be array. Found " + arguments[i + 1]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if constexpr (std::is_same_v) { nested_types[2 * i] = recursiveRemoveLowCardinality(array_type->getKeyType()); nested_types[2 * i + 1] = recursiveRemoveLowCardinality(array_type->getValueType()); } else if constexpr (std::is_same_v) { nested_types[i] = recursiveRemoveLowCardinality(array_type->getNestedType()); } } const DataTypeFunction * function_type = checkAndGetDataType(arguments[0].get()); if (!function_type || function_type->getArgumentTypes().size() != nested_types.size()) throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for this overload of {} must be a function with {} arguments, found {} instead", getName(), nested_types.size(), arguments[0]->getName()); arguments[0] = std::make_shared(nested_types); } DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { size_t min_args = Impl::needExpression() ? 2 : 1; if (arguments.size() < min_args) throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} needs at least {} argument, passed {}", getName(), min_args, arguments.size()); if ((arguments.size() == 1) && std::is_same_v) { const auto * data_type = checkAndGetDataType(arguments[0].type.get()); if (!data_type) throw Exception("The only argument for function " + getName() + " must be array. Found " + arguments[0].type->getName() + " instead", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); DataTypePtr nested_type = data_type->getNestedType(); if (Impl::needBoolean() && !isUInt8(nested_type)) throw Exception("The only argument for function " + getName() + " must be array of UInt8. Found " + arguments[0].type->getName() + " instead", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if constexpr (std::is_same_v) return Impl::getReturnType(nested_type, nested_type); else throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unreachable code reached"); } else { if (arguments.size() > 2 && Impl::needOneArray()) throw Exception("Function " + getName() + " needs one argument with data", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); const auto * data_type_function = checkAndGetDataType(arguments[0].type.get()); if (!data_type_function) throw Exception("First argument for function " + getName() + " must be a function", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); /// The types of the remaining arguments are already checked in getLambdaArgumentTypes. DataTypePtr return_type = removeLowCardinality(data_type_function->getReturnType()); /// Special cases when we need boolean lambda result: /// - lambda may return Nullable(UInt8) column, in this case after lambda execution we will /// replace all NULLs with 0 and return nested UInt8 column. /// - lambda may return Nothing or Nullable(Nothing) because of default implementation of functions /// for these types. In this case we will just create UInt8 const column full of 0. if (Impl::needBoolean() && !isUInt8(removeNullable(return_type)) && !isNothing(removeNullable(return_type))) throw Exception("Expression for function " + getName() + " must return UInt8 or Nullable(UInt8), found " + return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); static_assert( std::is_same_v || std::is_same_v, "unsupported type"); if (arguments.size() < 2) { throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "{}", arguments.size()); } const auto * first_array_type = checkAndGetDataType(arguments[1].type.get()); if (!first_array_type) throw DB::Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Unsupported type {}", arguments[1].type->getName()); if constexpr (std::is_same_v) return Impl::getReturnType(return_type, first_array_type->getNestedType()); if constexpr (std::is_same_v) return Impl::getReturnType(return_type, first_array_type->getKeyValueTypes()); throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unreachable code reached"); } } ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override { if (arguments.size() == 1) { ColumnPtr column_array_ptr = arguments[0].column; const auto * column_array = checkAndGetColumn(column_array_ptr.get()); if (!column_array) { const ColumnConst * column_const_array = checkAndGetColumnConst(column_array_ptr.get()); if (!column_const_array) throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN); column_array_ptr = column_const_array->convertToFullColumn(); column_array = assert_cast(column_array_ptr.get()); } if constexpr (std::is_same_v) { return Impl::execute(*column_array, column_array->getNestedColumn().getDataPtr()); } else { return Impl::execute(*column_array, column_array->getDataPtr()); } } else { const auto & column_with_type_and_name = arguments[0]; if (!column_with_type_and_name.column) throw Exception("First argument for function " + getName() + " must be a function.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); const auto * column_function = typeid_cast(column_with_type_and_name.column.get()); if (!column_function) throw Exception("First argument for function " + getName() + " must be a function.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ColumnPtr offsets_column; ColumnPtr column_first_array_ptr; const typename Impl::column_type * column_first_array = nullptr; ColumnsWithTypeAndName arrays; arrays.reserve(arguments.size() - 1); for (size_t i = 1; i < arguments.size(); ++i) { const auto & array_with_type_and_name = arguments[i]; ColumnPtr column_array_ptr = array_with_type_and_name.column; const auto * column_array = checkAndGetColumn(column_array_ptr.get()); const DataTypePtr & array_type_ptr = array_with_type_and_name.type; const auto * array_type = checkAndGetDataType(array_type_ptr.get()); if (!column_array) { const ColumnConst * column_const_array = checkAndGetColumnConst(column_array_ptr.get()); if (!column_const_array) throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN); column_array_ptr = recursiveRemoveLowCardinality(column_const_array->convertToFullColumn()); column_array = checkAndGetColumn(column_array_ptr.get()); } if (!array_type) throw Exception("Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if (!offsets_column) { offsets_column = getOffsetsPtr(*column_array); } else { /// The first condition is optimization: do not compare data if the pointers are equal. if (getOffsetsPtr(*column_array) != offsets_column && getOffsets(*column_array) != typeid_cast(*offsets_column).getData()) throw Exception("Arrays passed to " + getName() + " must have equal size", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); } if (i == 1) { column_first_array_ptr = column_array_ptr; column_first_array = column_array; } if constexpr (std::is_same_v) { arrays.emplace_back(ColumnWithTypeAndName( column_array->getNestedData().getColumnPtr(0), recursiveRemoveLowCardinality(array_type->getKeyType()), array_with_type_and_name.name+".key")); arrays.emplace_back(ColumnWithTypeAndName( column_array->getNestedData().getColumnPtr(1), recursiveRemoveLowCardinality(array_type->getValueType()), array_with_type_and_name.name+".value")); } else { arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(), recursiveRemoveLowCardinality(array_type->getNestedType()), array_with_type_and_name.name)); } } /// Put all the necessary columns multiplied by the sizes of arrays into the columns. auto replicated_column_function_ptr = IColumn::mutate(column_function->replicate(getOffsets(*column_first_array))); auto * replicated_column_function = typeid_cast(replicated_column_function_ptr.get()); replicated_column_function->appendArguments(arrays); auto lambda_result = replicated_column_function->reduce(); if (lambda_result.column->lowCardinality()) lambda_result.column = lambda_result.column->convertToFullColumnIfLowCardinality(); if (Impl::needBoolean()) { /// If result column is Nothing or Nullable(Nothing), just create const UInt8 column with 0 value. if (isNothing(removeNullable(lambda_result.type))) { auto result_type = std::make_shared(); lambda_result.column = result_type->createColumnConst(lambda_result.column->size(), 0); } /// If result column is Nullable(UInt8), then extract nested column and write 0 in all rows /// when we have NULL. else if (lambda_result.column->isNullable()) { auto result_column = IColumn::mutate(std::move(lambda_result.column)); if (isColumnConst(*result_column)) { UInt8 value = result_column->empty() ? 0 : result_column->getBool(0); auto result_type = std::make_shared(); lambda_result.column = result_type->createColumnConst(result_column->size(), value); } else { auto * column_nullable = assert_cast(result_column.get()); auto & null_map = column_nullable->getNullMapData(); auto nested_column = IColumn::mutate(std::move(column_nullable->getNestedColumnPtr())); auto & nested_data = assert_cast(nested_column.get())->getData(); for (size_t i = 0; i != nested_data.size(); ++i) { if (null_map[i]) nested_data[i] = 0; } lambda_result.column = std::move(nested_column); } } } return Impl::execute(*column_first_array, lambda_result.column); } } }; }