/// NOTE(review): in this copy of the file every #include target and every template argument appears to
/// have been stripped (e.g. `std::make_shared()`, `std::make_unique()`, `checkAndGetColumnConst(...)`,
/// `checkAndGetDataType(...)`, `getValue()`, `typeid_cast(...)`, `std::vector`, `PODArray`,
/// `registerFunction()`). As written it cannot compile; restore the header names and template
/// arguments from the upstream source before building.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace DB
{

/// Error codes thrown by this file; they are only declared here and defined elsewhere.
namespace ErrorCodes
{
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int BAD_ARGUMENTS;
}


/** Applies an aggregate function to an array and returns its result.
  * If the aggregate function takes multiple arguments, this function can be applied to multiple arrays of the same size.
  *
  * arrayReduce('agg', arr1, ...) - apply the aggregate function `agg` to arrays `arr1...`
  * If multiple arrays are passed, elements at corresponding positions are passed as multiple arguments to the aggregate function.
  */
class FunctionArrayReduce : public IFunction
{
public:
    static constexpr auto name = "arrayReduce";

    /// Factory hook: builds a fresh instance; the context argument is unused.
    static FunctionPtr create(ContextPtr) { return std::make_shared(); }

    String getName() const override { return name; }

    /// Variadic: the real argument-count check (at least 2) is done in getReturnTypeImpl.
    bool isVariadic() const override { return true; }
    size_t getNumberOfArguments() const override { return 0; }

    bool useDefaultImplementationForConstants() const override { return true; }

    /// Argument 0 is the aggregate function name (with optional parameters) and must be a constant.
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; }

    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;

    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;

private:
    /// lazy initialization in getReturnTypeImpl
    /// TODO: init in OverloadResolver
    /// NOTE(review): the pointer is set only on the first getReturnTypeImpl call and never rebuilt,
    /// even if a later call passes different argument types. executeImpl dereferences it without a null
    /// check, so it relies on getReturnTypeImpl having run on this same instance first.
    mutable AggregateFunctionPtr aggregate_function;
};


/// Validates the arguments of arrayReduce and returns the result type of the requested aggregate function.
///
/// arguments[0]   - constant name of the aggregate function, optionally with parameters, e.g. "quantile(0.99)".
/// arguments[1..] - array-typed arguments; their element (nested) types become the aggregate function's argument types.
///
/// On the first call this also creates the aggregate function through AggregateFunctionFactory and caches it
/// in `aggregate_function`; later calls reuse the cached function.
///
/// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH if fewer than 2 arguments are passed.
/// Throws ILLEGAL_TYPE_OF_ARGUMENT if the first argument is not a constant column or a later argument is not an array.
/// Throws BAD_ARGUMENTS if the aggregate function name is an empty string.
DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
{
    /// The first argument is a constant string with the name of the aggregate function
    /// (possibly with parameters in parentheses, for example: "quantile(0.99)").

    if (arguments.size() < 2)
        throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
            + toString(arguments.size()) + ", should be at least 2.",
            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    const ColumnConst * aggregate_function_name_column = checkAndGetColumnConst(arguments[0].column.get());
    if (!aggregate_function_name_column)
        throw Exception("First argument for function " + getName() + " must be constant string: name of aggregate function.",
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

    /// One aggregate-function argument type per array argument: the array's element type.
    DataTypes argument_types(arguments.size() - 1);
    for (size_t i = 1, size = arguments.size(); i < size; ++i)
    {
        const DataTypeArray * arg = checkAndGetDataType(arguments[i].type.get());
        if (!arg)
            throw Exception("Argument " + toString(i) + " for function " + getName() + " must be an array but it has type "
                + arguments[i].type->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        argument_types[i - 1] = arg->getNestedType();
    }

    /// Build the aggregate function only once; see the NOTE on `aggregate_function`.
    if (!aggregate_function)
    {
        String aggregate_function_name_with_params = aggregate_function_name_column->getValue();

        if (aggregate_function_name_with_params.empty())
            throw Exception("First argument for function " + getName() + " (name of aggregate function) cannot be empty.",
                ErrorCodes::BAD_ARGUMENTS);

        /// Split e.g. "quantile(0.99)" into the bare name and its parameter list.
        String aggregate_function_name;
        Array params_row;
        getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
                                                   aggregate_function_name, params_row, "function " + getName());

        AggregateFunctionProperties properties;
        aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
    }

    return aggregate_function->getReturnType();
}


/// Runs the cached aggregate function once per row over that row's array elements.
///
/// For every row, one aggregate state is allocated from an arena and created. The elements of the row's
/// arrays are added to it via addBatchArray, and the finished state is written into a new column of
/// `result_type`.
///
/// Constant array arguments are materialized into full columns first. All array arguments must have
/// identical offsets, i.e. the same length in every row.
///
/// Throws ILLEGAL_COLUMN if an argument column is neither an array column nor a constant, or if a -State
/// function's result would go into a column that is not an aggregate-function column.
/// Throws SIZES_OF_ARRAYS_DOESNT_MATCH if the array arguments differ in length in any row.
///
/// NOTE(review): dereferences `aggregate_function` unconditionally; it must already have been set by
/// getReturnTypeImpl on this instance.
ColumnPtr FunctionArrayReduce::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const
{
    IAggregateFunction & agg_func = *aggregate_function;

    /// Memory for the per-row aggregate states below is taken from this arena.
    std::unique_ptr arena = std::make_unique();

    /// Aggregate functions do not support constant columns. Therefore, we materialize them.
    /// The materialized columns are kept alive here so that the raw pointers stored in
    /// aggregate_arguments_vec stay valid for the whole function.
    std::vector materialized_columns;

    const size_t num_arguments_columns = arguments.size() - 1;

    /// Flattened element columns (one per array argument), passed to the aggregate function as its arguments.
    std::vector aggregate_arguments_vec(num_arguments_columns);

    /// Offsets of the first array argument. They are set on iteration i == 0 and are non-null as long as
    /// there is at least one array argument. getReturnTypeImpl requires >= 2 arguments in total, assuming
    /// executeImpl is called with the same argument list (TODO confirm).
    const ColumnArray::Offsets * offsets = nullptr;

    for (size_t i = 0; i < num_arguments_columns; ++i)
    {
        const IColumn * col = arguments[i + 1].column.get();
        const ColumnArray::Offsets * offsets_i = nullptr;
        if (const ColumnArray * arr = checkAndGetColumn(col))
        {
            aggregate_arguments_vec[i] = &arr->getData();
            offsets_i = &arr->getOffsets();
        }
        else if (const ColumnConst * const_arr = checkAndGetColumnConst(col))
        {
            materialized_columns.emplace_back(const_arr->convertToFullColumn());
            const auto & materialized_arr = typeid_cast(*materialized_columns.back());
            aggregate_arguments_vec[i] = &materialized_arr.getData();
            offsets_i = &materialized_arr.getOffsets();
        }
        else
            throw Exception("Illegal column " + col->getName() + " as argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);

        /// Compare the whole offsets arrays, so every row must have equal lengths, not just the totals.
        if (i == 0)
            offsets = offsets_i;
        else if (*offsets_i != *offsets)
            throw Exception("Lengths of all arrays passed to " + getName() + " must be equal.",
                ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
    }
    const IColumn ** aggregate_arguments = aggregate_arguments_vec.data();

    MutableColumnPtr result_holder = result_type->createColumn();
    IColumn & res_col = *result_holder;

    /// AggregateFunction's states should be inserted into column using specific way
    auto * res_col_aggregate_function = typeid_cast(&res_col);

    if (!res_col_aggregate_function && agg_func.isState())
        throw Exception("State function " + agg_func.getName() + " inserts results into non-state column "
                        + result_type->getName(), ErrorCodes::ILLEGAL_COLUMN);

    /// One aggregate state per input row.
    PODArray places(input_rows_count);
    for (size_t i = 0; i < input_rows_count; ++i)
    {
        places[i] = arena->alignedAlloc(agg_func.sizeOfData(), agg_func.alignOfData());
        try
        {
            agg_func.create(places[i]);
        }
        catch (...)
        {
            /// The scope guard below is not installed yet, so undo by hand. Only states 0..i-1 were
            /// successfully created; state i was not.
            for (size_t j = 0; j < i; ++j)
                agg_func.destroy(places[j]);
            throw;
        }
    }

    /// From here on, destroy every created state when this scope is left, including when
    /// addBatchArray / insertResultInto / insertFrom throw.
    SCOPE_EXIT_MEMORY_SAFE({
        for (size_t i = 0; i < input_rows_count; ++i)
            agg_func.destroy(places[i]);
    });

    {
        auto * that = &agg_func;
        /// Unnest consecutive trailing -State combinators
        /// NOTE(review): this presumably relies on a -State wrapper using the same state layout as its nested
        /// function, so adding directly through the nested function fills the same `places` - confirm.
        while (auto * func = typeid_cast(that))
            that = func->getNestedFunction().get();

        /// Feed each row's slice of the flattened element columns (delimited by `offsets`) into that row's state.
        that->addBatchArray(input_rows_count, places.data(), 0, aggregate_arguments, offsets->data(), arena.get());
    }

    /// Ordinary result columns receive the finalized value. Aggregate-function columns receive the state via
    /// insertFrom, which presumably copies/merges it into the column's own storage, since every place is
    /// destroyed by the scope guard above - confirm.
    for (size_t i = 0; i < input_rows_count; ++i)
        if (!res_col_aggregate_function)
            agg_func.insertResultInto(places[i], res_col, arena.get());
        else
            res_col_aggregate_function->insertFrom(places[i]);
    return result_holder;
}


/// Registers arrayReduce with the function factory.
void registerFunctionArrayReduce(FunctionFactory & factory)
{
    factory.registerFunction();
}

}