#include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int LOGICAL_ERROR; extern const int ILLEGAL_TYPE_OF_ARGUMENT; } class FunctionArrayPush : public IFunction { public: FunctionArrayPush(const Context & context, bool push_front, const char * name) : context(context), push_front(push_front), name(name) {} String getName() const override { return name; } bool isVariadic() const override { return false; } size_t getNumberOfArguments() const override { return 2; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { if (arguments[0]->onlyNull()) return arguments[0]; auto array_type = typeid_cast(arguments[0].get()); if (!array_type) throw Exception("First argument for function " + getName() + " must be an array but it has type " + arguments[0]->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); auto nested_type = array_type->getNestedType(); DataTypes types = {nested_type, arguments[1]}; return std::make_shared(getLeastSupertype(types)); } void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override { const auto & return_type = block.getByPosition(result).type; if (return_type->onlyNull()) { block.getByPosition(result).column = return_type->createColumnConstWithDefaultValue(input_rows_count); return; } auto result_column = return_type->createColumn(); auto array_column = block.getByPosition(arguments[0]).column; auto appended_column = block.getByPosition(arguments[1]).column; if (!block.getByPosition(arguments[0]).type->equals(*return_type)) array_column = castColumn(block.getByPosition(arguments[0]), return_type, context); const DataTypePtr & return_nested_type = typeid_cast(*return_type).getNestedType(); if (!block.getByPosition(arguments[1]).type->equals(*return_nested_type)) appended_column = castColumn(block.getByPosition(arguments[1]), return_nested_type, context); std::unique_ptr array_source; std::unique_ptr value_source; size_t size = array_column->size(); bool is_const = false; if (auto const_array_column = typeid_cast(array_column.get())) { is_const = true; array_column = const_array_column->getDataColumnPtr(); } if (auto argument_column_array = typeid_cast(array_column.get())) array_source = GatherUtils::createArraySource(*argument_column_array, is_const, size); else throw Exception{"First arguments for function " + getName() + " must be array.", ErrorCodes::LOGICAL_ERROR}; bool is_appended_const = false; if (auto const_appended_column = typeid_cast(appended_column.get())) { is_appended_const = true; appended_column = const_appended_column->getDataColumnPtr(); } value_source = GatherUtils::createValueSource(*appended_column, is_appended_const, size); auto sink = GatherUtils::createArraySink(typeid_cast(*result_column), size); GatherUtils::push(*array_source, *value_source, *sink, push_front); block.getByPosition(result).column = std::move(result_column); } bool useDefaultImplementationForConstants() const override { return true; } bool useDefaultImplementationForNulls() const override { return false; } private: const Context & context; bool push_front; const char * name; }; }