diff --git a/src/Interpreters/InterpreterCreateFunctionQuery.cpp b/src/Interpreters/InterpreterCreateFunctionQuery.cpp index ccb5f4040ec..39fec4a941c 100644 --- a/src/Interpreters/InterpreterCreateFunctionQuery.cpp +++ b/src/Interpreters/InterpreterCreateFunctionQuery.cpp @@ -31,20 +31,32 @@ BlockIO InterpreterCreateFunctionQuery::execute() if (!create_function_query) throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Expected CREATE FUNCTION query"); + auto & user_defined_function_factory = UserDefinedSQLFunctionFactory::instance(); + auto & function_name = create_function_query->function_name; + + bool if_not_exists = create_function_query->if_not_exists; + bool replace = create_function_query->or_replace; + + create_function_query->if_not_exists = false; + create_function_query->or_replace = false; + + if (if_not_exists && user_defined_function_factory.tryGet(function_name) != nullptr) + return {}; + validateFunction(create_function_query->function_core, function_name); - UserDefinedSQLFunctionFactory::instance().registerFunction(function_name, query_ptr); + user_defined_function_factory.registerFunction(function_name, query_ptr, replace); - if (!persist_function) + if (persist_function) { try { - UserDefinedSQLObjectsLoader::instance().storeObject(current_context, UserDefinedSQLObjectType::Function, function_name, *query_ptr); + UserDefinedSQLObjectsLoader::instance().storeObject(current_context, UserDefinedSQLObjectType::Function, function_name, *query_ptr, replace); } catch (Exception & exception) { - UserDefinedSQLFunctionFactory::instance().unregisterFunction(function_name); + user_defined_function_factory.unregisterFunction(function_name); exception.addMessage(fmt::format("while storing user defined function {} on disk", backQuote(function_name))); throw; } diff --git a/src/Interpreters/InterpreterFactory.cpp b/src/Interpreters/InterpreterFactory.cpp index 54307ae848b..fcf5f19aef6 100644 --- a/src/Interpreters/InterpreterFactory.cpp +++ b/src/Interpreters/InterpreterFactory.cpp @@ -278,7 +278,7 @@ std::unique_ptr InterpreterFactory::get(ASTPtr & query, ContextMut } else if (query->as()) { - return std::make_unique(query, context, false /*is_internal*/); + return std::make_unique(query, context, true /*persist_function*/); } else if (query->as()) { diff --git a/src/Interpreters/UserDefinedSQLFunctionFactory.cpp b/src/Interpreters/UserDefinedSQLFunctionFactory.cpp index 1d2a80305c6..f036741ca21 100644 --- a/src/Interpreters/UserDefinedSQLFunctionFactory.cpp +++ b/src/Interpreters/UserDefinedSQLFunctionFactory.cpp @@ -19,7 +19,7 @@ UserDefinedSQLFunctionFactory & UserDefinedSQLFunctionFactory::instance() return result; } -void UserDefinedSQLFunctionFactory::registerFunction(const String & function_name, ASTPtr create_function_query) +void UserDefinedSQLFunctionFactory::registerFunction(const String & function_name, ASTPtr create_function_query, bool replace) { if (FunctionFactory::instance().hasNameOrAlias(function_name)) throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The function '{}' already exists", function_name); @@ -29,11 +29,17 @@ void UserDefinedSQLFunctionFactory::registerFunction(const String & function_nam std::lock_guard lock(mutex); - auto [_, inserted] = function_name_to_create_query.emplace(function_name, std::move(create_function_query)); + auto [it, inserted] = function_name_to_create_query.emplace(function_name, create_function_query); + if (!inserted) - throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, - "The function name '{}' is not unique", - function_name); + { + if (replace) + it->second = std::move(create_function_query); + else + throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, + "The function name '{}' is not unique", + function_name); + } } void UserDefinedSQLFunctionFactory::unregisterFunction(const String & function_name) diff --git a/src/Interpreters/UserDefinedSQLFunctionFactory.h b/src/Interpreters/UserDefinedSQLFunctionFactory.h index 6838c2f9892..6487b951705 100644 --- a/src/Interpreters/UserDefinedSQLFunctionFactory.h +++ b/src/Interpreters/UserDefinedSQLFunctionFactory.h @@ -10,21 +10,31 @@ namespace DB { +/// Factory for SQLUserDefinedFunctions class UserDefinedSQLFunctionFactory : public IHints<1, UserDefinedSQLFunctionFactory> { public: static UserDefinedSQLFunctionFactory & instance(); - void registerFunction(const String & function_name, ASTPtr create_function_query); + /** Register function for function_name in factory for specified create_function_query. + * If replace = true and function with function_name already exists replace it with create_function_query. + * Otherwise throws exception. + */ + void registerFunction(const String & function_name, ASTPtr create_function_query, bool replace); + /// Unregister function for function_name void unregisterFunction(const String & function_name); + /// Get function create query for function_name. If no function registered with function_name throws exception. ASTPtr get(const String & function_name) const; + /// Get function create query for function_name. If no function registered with function_name return nullptr. ASTPtr tryGet(const String & function_name) const; + /// Check if function with function_name registered. bool has(const String & function_name) const; + /// Get all user defined functions registered names. std::vector getAllRegisteredNames() const override; private: diff --git a/src/Interpreters/UserDefinedSQLObjectsLoader.cpp b/src/Interpreters/UserDefinedSQLObjectsLoader.cpp index e4eb97f3002..a71f1f0799c 100644 --- a/src/Interpreters/UserDefinedSQLObjectsLoader.cpp +++ b/src/Interpreters/UserDefinedSQLObjectsLoader.cpp @@ -69,7 +69,7 @@ void UserDefinedSQLObjectsLoader::loadUserDefinedObject(ContextPtr context, User 0, context->getSettingsRef().max_parser_depth); - InterpreterCreateFunctionQuery interpreter(ast, context, true /*is internal*/); + InterpreterCreateFunctionQuery interpreter(ast, context, false /*persist_function*/); interpreter.execute(); } } @@ -111,7 +111,7 @@ void UserDefinedSQLObjectsLoader::loadObjects(ContextPtr context) } } -void UserDefinedSQLObjectsLoader::storeObject(ContextPtr context, UserDefinedSQLObjectType object_type, const String & object_name, const IAST & ast) +void UserDefinedSQLObjectsLoader::storeObject(ContextPtr context, UserDefinedSQLObjectType object_type, const String & object_name, const IAST & ast, bool replace) { if (unlikely(!enable_persistence)) return; @@ -127,7 +127,7 @@ void UserDefinedSQLObjectsLoader::storeObject(ContextPtr context, UserDefinedSQL } } - if (std::filesystem::exists(file_path)) + if (!replace && std::filesystem::exists(file_path)) throw Exception(ErrorCodes::OBJECT_ALREADY_STORED_ON_DISK, "User defined object {} already stored on disk", backQuote(file_path)); LOG_DEBUG(log, "Storing object {} to file {}", backQuote(object_name), file_path); @@ -135,9 +135,9 @@ void UserDefinedSQLObjectsLoader::storeObject(ContextPtr context, UserDefinedSQL WriteBufferFromOwnString create_statement_buf; formatAST(ast, create_statement_buf, false); writeChar('\n', create_statement_buf); - String create_statement = create_statement_buf.str(); - WriteBufferFromFile out(file_path, create_statement.size(), O_WRONLY | O_CREAT | O_EXCL); + + WriteBufferFromFile out(file_path, create_statement.size()); writeString(create_statement, out); out.next(); if (context->getSettingsRef().fsync_metadata) diff --git a/src/Interpreters/UserDefinedSQLObjectsLoader.h b/src/Interpreters/UserDefinedSQLObjectsLoader.h index 17493933f21..2e747f67a8d 100644 --- a/src/Interpreters/UserDefinedSQLObjectsLoader.h +++ b/src/Interpreters/UserDefinedSQLObjectsLoader.h @@ -21,7 +21,7 @@ public: UserDefinedSQLObjectsLoader(); void loadObjects(ContextPtr context); - void storeObject(ContextPtr context, UserDefinedSQLObjectType object_type, const String & object_name, const IAST & ast); + void storeObject(ContextPtr context, UserDefinedSQLObjectType object_type, const String & object_name, const IAST & ast, bool replace); void removeObject(ContextPtr context, UserDefinedSQLObjectType object_type, const String & object_name); /// For ClickHouse local if path is not set we can disable loader. diff --git a/src/Parsers/ASTCreateFunctionQuery.cpp b/src/Parsers/ASTCreateFunctionQuery.cpp index 0b3991ddc44..4e1e7de660d 100644 --- a/src/Parsers/ASTCreateFunctionQuery.cpp +++ b/src/Parsers/ASTCreateFunctionQuery.cpp @@ -12,7 +12,18 @@ ASTPtr ASTCreateFunctionQuery::clone() const void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, IAST::FormatState & state, IAST::FormatStateStacked frame) const { - settings.ostr << (settings.hilite ? hilite_keyword : "") << "CREATE FUNCTION " << (settings.hilite ? hilite_none : ""); + settings.ostr << (settings.hilite ? hilite_keyword : "") << "CREATE "; + + if (or_replace) + settings.ostr << "OR REPLACE "; + + settings.ostr << "FUNCTION "; + + if (if_not_exists) + settings.ostr << "IF NOT EXISTS "; + + settings.ostr << (settings.hilite ? hilite_none : ""); + settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(function_name) << (settings.hilite ? hilite_none : ""); settings.ostr << (settings.hilite ? hilite_keyword : "") << " AS " << (settings.hilite ? hilite_none : ""); function_core->formatImpl(settings, state, frame); diff --git a/src/Parsers/ASTCreateFunctionQuery.h b/src/Parsers/ASTCreateFunctionQuery.h index 3adddad8fbd..a58fe64c435 100644 --- a/src/Parsers/ASTCreateFunctionQuery.h +++ b/src/Parsers/ASTCreateFunctionQuery.h @@ -12,6 +12,9 @@ public: String function_name; ASTPtr function_core; + bool or_replace = false; + bool if_not_exists = false; + String getID(char) const override { return "CreateFunctionQuery"; } ASTPtr clone() const override; diff --git a/src/Parsers/ParserCreateFunctionQuery.cpp b/src/Parsers/ParserCreateFunctionQuery.cpp index fbfd02415e7..5d84b6bc2dc 100644 --- a/src/Parsers/ParserCreateFunctionQuery.cpp +++ b/src/Parsers/ParserCreateFunctionQuery.cpp @@ -1,10 +1,12 @@ +#include + #include #include #include #include #include #include -#include + namespace DB { @@ -13,6 +15,8 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp { ParserKeyword s_create("CREATE"); ParserKeyword s_function("FUNCTION"); + ParserKeyword s_or_replace("OR REPLACE"); + ParserKeyword s_if_not_exists("IF NOT EXISTS"); ParserIdentifier function_name_p; ParserKeyword s_as("AS"); ParserLambdaExpression lambda_p; @@ -20,12 +24,21 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp ASTPtr function_name; ASTPtr function_core; + bool or_replace = false; + bool if_not_exists = false; + if (!s_create.ignore(pos, expected)) return false; + if (s_or_replace.ignore(pos, expected)) + or_replace = true; + if (!s_function.ignore(pos, expected)) return false; + if (!or_replace && s_if_not_exists.ignore(pos, expected)) + if_not_exists = true; + if (!function_name_p.parse(pos, function_name, expected)) return false; @@ -40,6 +53,8 @@ bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Exp create_function_query->function_name = function_name->as().name(); create_function_query->function_core = function_core; + create_function_query->or_replace = or_replace; + create_function_query->if_not_exists = if_not_exists; return true; } diff --git a/tests/queries/0_stateless/02101_sql_user_defined_functions_create_or_replace.reference b/tests/queries/0_stateless/02101_sql_user_defined_functions_create_or_replace.reference new file mode 100644 index 00000000000..437cc81afba --- /dev/null +++ b/tests/queries/0_stateless/02101_sql_user_defined_functions_create_or_replace.reference @@ -0,0 +1,4 @@ +CREATE FUNCTION `02101_test_function` AS x -> (x + 1) +2 +CREATE FUNCTION `02101_test_function` AS x -> (x + 2) +3 diff --git a/tests/queries/0_stateless/02101_sql_user_defined_functions_create_or_replace.sql b/tests/queries/0_stateless/02101_sql_user_defined_functions_create_or_replace.sql new file mode 100644 index 00000000000..7b0ad311bd4 --- /dev/null +++ b/tests/queries/0_stateless/02101_sql_user_defined_functions_create_or_replace.sql @@ -0,0 +1,13 @@ +-- Tags: no-parallel + +CREATE OR REPLACE FUNCTION 02101_test_function AS x -> x + 1; + +SELECT create_query FROM system.functions WHERE name = '02101_test_function'; +SELECT 02101_test_function(1); + +CREATE OR REPLACE FUNCTION 02101_test_function AS x -> x + 2; + +SELECT create_query FROM system.functions WHERE name = '02101_test_function'; +SELECT 02101_test_function(1); + +DROP FUNCTION 02101_test_function; diff --git a/tests/queries/0_stateless/02102_sql_user_defined_functions_create_if_not_exists.reference b/tests/queries/0_stateless/02102_sql_user_defined_functions_create_if_not_exists.reference new file mode 100644 index 00000000000..0cfbf08886f --- /dev/null +++ b/tests/queries/0_stateless/02102_sql_user_defined_functions_create_if_not_exists.reference @@ -0,0 +1 @@ +2 diff --git a/tests/queries/0_stateless/02102_sql_user_defined_functions_create_if_not_exists.sql b/tests/queries/0_stateless/02102_sql_user_defined_functions_create_if_not_exists.sql new file mode 100644 index 00000000000..092fa660cb0 --- /dev/null +++ b/tests/queries/0_stateless/02102_sql_user_defined_functions_create_if_not_exists.sql @@ -0,0 +1,8 @@ +-- Tags: no-parallel + +CREATE FUNCTION IF NOT EXISTS 02102_test_function AS x -> x + 1; +SELECT 02102_test_function(1); + +CREATE FUNCTION 02102_test_function AS x -> x + 1; --{serverError 609} +CREATE FUNCTION IF NOT EXISTS 02102_test_function AS x -> x + 1; +DROP FUNCTION 02102_test_function;