Updated UserDefinedSQLFunctionFactory

This commit is contained in:
Maksim Kita 2021-10-27 18:49:18 +03:00
parent d523b28459
commit d6c0cde173
8 changed files with 105 additions and 60 deletions

View File

@ -577,7 +577,7 @@
M(607, BACKUP_ELEMENT_DUPLICATE) \ M(607, BACKUP_ELEMENT_DUPLICATE) \
M(608, CANNOT_RESTORE_TABLE) \ M(608, CANNOT_RESTORE_TABLE) \
M(609, FUNCTION_ALREADY_EXISTS) \ M(609, FUNCTION_ALREADY_EXISTS) \
M(610, CANNOT_DROP_SYSTEM_FUNCTION) \ M(610, CANNOT_DROP_FUNCTION) \
M(611, CANNOT_CREATE_RECURSIVE_FUNCTION) \ M(611, CANNOT_CREATE_RECURSIVE_FUNCTION) \
M(612, OBJECT_ALREADY_STORED_ON_DISK) \ M(612, OBJECT_ALREADY_STORED_ON_DISK) \
M(613, OBJECT_WAS_NOT_STORED_ON_DISK) \ M(613, OBJECT_WAS_NOT_STORED_ON_DISK) \

View File

@ -24,15 +24,15 @@ namespace ErrorCodes
BlockIO InterpreterCreateFunctionQuery::execute() BlockIO InterpreterCreateFunctionQuery::execute()
{ {
FunctionNameNormalizer().visit(query_ptr.get()); FunctionNameNormalizer().visit(query_ptr.get());
auto * create_function_query = query_ptr->as<ASTCreateFunctionQuery>(); ASTCreateFunctionQuery & create_function_query = query_ptr->as<ASTCreateFunctionQuery &>();
if (!create_function_query)
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Expected CREATE FUNCTION query");
AccessRightsElements access_rights_elements; AccessRightsElements access_rights_elements;
access_rights_elements.emplace_back(AccessType::CREATE_FUNCTION); access_rights_elements.emplace_back(AccessType::CREATE_FUNCTION);
if (!create_function_query->cluster.empty()) if (create_function_query.or_replace)
access_rights_elements.emplace_back(AccessType::DROP_FUNCTION);
if (!create_function_query.cluster.empty())
return executeDDLQueryOnCluster(query_ptr, getContext(), access_rights_elements); return executeDDLQueryOnCluster(query_ptr, getContext(), access_rights_elements);
auto current_context = getContext(); auto current_context = getContext();
@ -40,34 +40,16 @@ BlockIO InterpreterCreateFunctionQuery::execute()
auto & user_defined_function_factory = UserDefinedSQLFunctionFactory::instance(); auto & user_defined_function_factory = UserDefinedSQLFunctionFactory::instance();
auto & function_name = create_function_query->function_name; auto & function_name = create_function_query.function_name;
bool if_not_exists = create_function_query->if_not_exists; bool if_not_exists = create_function_query.if_not_exists;
bool replace = create_function_query->or_replace; bool replace = create_function_query.or_replace;
create_function_query->if_not_exists = false; create_function_query.if_not_exists = false;
create_function_query->or_replace = false; create_function_query.or_replace = false;
if (if_not_exists && user_defined_function_factory.tryGet(function_name) != nullptr) validateFunction(create_function_query.function_core, function_name);
return {}; user_defined_function_factory.registerFunction(current_context, function_name, query_ptr, replace, if_not_exists, persist_function);
validateFunction(create_function_query->function_core, function_name);
user_defined_function_factory.registerFunction(function_name, query_ptr, replace);
if (persist_function)
{
try
{
UserDefinedSQLObjectsLoader::instance().storeObject(current_context, UserDefinedSQLObjectType::Function, function_name, *query_ptr, replace);
}
catch (Exception & exception)
{
user_defined_function_factory.unregisterFunction(function_name);
exception.addMessage(fmt::format("while storing user defined function {} on disk", backQuote(function_name)));
throw;
}
}
return {}; return {};
} }

View File

@ -12,35 +12,21 @@
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int UNSUPPORTED_METHOD;
}
BlockIO InterpreterDropFunctionQuery::execute() BlockIO InterpreterDropFunctionQuery::execute()
{ {
FunctionNameNormalizer().visit(query_ptr.get()); FunctionNameNormalizer().visit(query_ptr.get());
auto * drop_function_query = query_ptr->as<ASTDropFunctionQuery>(); ASTDropFunctionQuery & drop_function_query = query_ptr->as<ASTDropFunctionQuery &>();
if (!drop_function_query)
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Expected DROP FUNCTION query");
AccessRightsElements access_rights_elements; AccessRightsElements access_rights_elements;
access_rights_elements.emplace_back(AccessType::DROP_FUNCTION); access_rights_elements.emplace_back(AccessType::DROP_FUNCTION);
if (!drop_function_query->cluster.empty()) if (!drop_function_query.cluster.empty())
return executeDDLQueryOnCluster(query_ptr, getContext(), access_rights_elements); return executeDDLQueryOnCluster(query_ptr, getContext(), access_rights_elements);
auto current_context = getContext(); auto current_context = getContext();
current_context->checkAccess(access_rights_elements); current_context->checkAccess(access_rights_elements);
auto & user_defined_functions_factory = UserDefinedSQLFunctionFactory::instance(); UserDefinedSQLFunctionFactory::instance().unregisterFunction(current_context, drop_function_query.function_name, drop_function_query.if_exists);
if (drop_function_query->if_exists && !user_defined_functions_factory.has(drop_function_query->function_name))
return {};
UserDefinedSQLFunctionFactory::instance().unregisterFunction(drop_function_query->function_name);
UserDefinedSQLObjectsLoader::instance().removeObject(current_context, UserDefinedSQLObjectType::Function, drop_function_query->function_name);
return {}; return {};
} }

View File

@ -206,6 +206,15 @@ FunctionOverloadResolverPtr UserDefinedExecutableFunctionFactory::tryGet(const S
return nullptr; return nullptr;
} }
bool UserDefinedExecutableFunctionFactory::has(const String & function_name, ContextPtr context)
{
const auto & loader = context->getExternalUserDefinedExecutableFunctionsLoader();
auto load_result = loader.getLoadResult(function_name);
bool result = load_result.object != nullptr;
return result;
}
std::vector<String> UserDefinedExecutableFunctionFactory::getRegisteredNames(ContextPtr context) std::vector<String> UserDefinedExecutableFunctionFactory::getRegisteredNames(ContextPtr context)
{ {
const auto & loader = context->getExternalUserDefinedExecutableFunctionsLoader(); const auto & loader = context->getExternalUserDefinedExecutableFunctionsLoader();

View File

@ -24,6 +24,8 @@ public:
static FunctionOverloadResolverPtr tryGet(const String & function_name, ContextPtr context); static FunctionOverloadResolverPtr tryGet(const String & function_name, ContextPtr context);
static bool has(const String & function_name, ContextPtr context);
static std::vector<String> getRegisteredNames(ContextPtr context); static std::vector<String> getRegisteredNames(ContextPtr context);
}; };

View File

@ -1,7 +1,13 @@
#include "UserDefinedSQLFunctionFactory.h" #include "UserDefinedSQLFunctionFactory.h"
#include <Common/quoteString.h>
#include <Functions/FunctionFactory.h> #include <Functions/FunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionFactory.h> #include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Interpreters/UserDefinedSQLObjectsLoader.h>
#include <Interpreters/UserDefinedExecutableFunctionFactory.h>
#include <Interpreters/Context.h>
namespace DB namespace DB
{ {
@ -10,7 +16,7 @@ namespace ErrorCodes
{ {
extern const int FUNCTION_ALREADY_EXISTS; extern const int FUNCTION_ALREADY_EXISTS;
extern const int UNKNOWN_FUNCTION; extern const int UNKNOWN_FUNCTION;
extern const int CANNOT_DROP_SYSTEM_FUNCTION; extern const int CANNOT_DROP_FUNCTION;
} }
UserDefinedSQLFunctionFactory & UserDefinedSQLFunctionFactory::instance() UserDefinedSQLFunctionFactory & UserDefinedSQLFunctionFactory::instance()
@ -19,13 +25,31 @@ UserDefinedSQLFunctionFactory & UserDefinedSQLFunctionFactory::instance()
return result; return result;
} }
void UserDefinedSQLFunctionFactory::registerFunction(const String & function_name, ASTPtr create_function_query, bool replace) void UserDefinedSQLFunctionFactory::registerFunction(ContextPtr context, const String & function_name, ASTPtr create_function_query, bool replace, bool if_not_exists, bool persist)
{ {
if (FunctionFactory::instance().hasNameOrAlias(function_name)) if (FunctionFactory::instance().hasNameOrAlias(function_name))
{
if (if_not_exists)
return;
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The function '{}' already exists", function_name); throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The function '{}' already exists", function_name);
}
if (AggregateFunctionFactory::instance().hasNameOrAlias(function_name)) if (AggregateFunctionFactory::instance().hasNameOrAlias(function_name))
{
if (if_not_exists)
return;
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The aggregate function '{}' already exists", function_name); throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The aggregate function '{}' already exists", function_name);
}
if (UserDefinedExecutableFunctionFactory::instance().has(function_name, context))
{
if (if_not_exists)
return;
throw Exception(ErrorCodes::CANNOT_DROP_FUNCTION, "User defined executable function '{}'", function_name);
}
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
@ -33,28 +57,63 @@ void UserDefinedSQLFunctionFactory::registerFunction(const String & function_nam
if (!inserted) if (!inserted)
{ {
if (if_not_exists)
return;
if (replace) if (replace)
it->second = std::move(create_function_query); it->second = create_function_query;
else else
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS,
"The function name '{}' is not unique", "The function name '{}' is not unique",
function_name); function_name);
} }
if (persist)
{
try
{
UserDefinedSQLObjectsLoader::instance().storeObject(context, UserDefinedSQLObjectType::Function, function_name, *create_function_query, replace);
}
catch (Exception & exception)
{
function_name_to_create_query.erase(it);
exception.addMessage(fmt::format("while storing user defined function {} on disk", backQuote(function_name)));
throw;
}
}
} }
void UserDefinedSQLFunctionFactory::unregisterFunction(const String & function_name) void UserDefinedSQLFunctionFactory::unregisterFunction(ContextPtr context, const String & function_name, bool if_exists)
{ {
if (FunctionFactory::instance().hasNameOrAlias(function_name) || if (FunctionFactory::instance().hasNameOrAlias(function_name) ||
AggregateFunctionFactory::instance().hasNameOrAlias(function_name)) AggregateFunctionFactory::instance().hasNameOrAlias(function_name))
throw Exception(ErrorCodes::CANNOT_DROP_SYSTEM_FUNCTION, "Cannot drop system function '{}'", function_name); throw Exception(ErrorCodes::CANNOT_DROP_FUNCTION, "Cannot drop system function '{}'", function_name);
if (UserDefinedExecutableFunctionFactory::instance().has(function_name, context))
throw Exception(ErrorCodes::CANNOT_DROP_FUNCTION, "Cannot drop user defined executable function '{}'", function_name);
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
auto it = function_name_to_create_query.find(function_name); auto it = function_name_to_create_query.find(function_name);
if (it == function_name_to_create_query.end()) if (it == function_name_to_create_query.end())
{
if (if_exists)
return;
throw Exception(ErrorCodes::UNKNOWN_FUNCTION, throw Exception(ErrorCodes::UNKNOWN_FUNCTION,
"The function name '{}' is not registered", "The function name '{}' is not registered",
function_name); function_name);
}
try
{
UserDefinedSQLObjectsLoader::instance().removeObject(context, UserDefinedSQLObjectType::Function, function_name);
}
catch (Exception & exception)
{
exception.addMessage(fmt::format("while removing user defined function {} from disk", backQuote(function_name)));
throw;
}
function_name_to_create_query.erase(it); function_name_to_create_query.erase(it);
} }

View File

@ -6,6 +6,8 @@
#include <Common/NamePrompter.h> #include <Common/NamePrompter.h>
#include <Parsers/ASTCreateFunctionQuery.h> #include <Parsers/ASTCreateFunctionQuery.h>
#include <Interpreters/Context_fwd.h>
namespace DB namespace DB
{ {
@ -17,13 +19,17 @@ public:
static UserDefinedSQLFunctionFactory & instance(); static UserDefinedSQLFunctionFactory & instance();
/** Register function for function_name in factory for specified create_function_query. /** Register function for function_name in factory for specified create_function_query.
* If replace = true and function with function_name already exists replace it with create_function_query. * If function exists and if_not_exists = false and replace = false throws exception.
* Otherwise throws exception. * If replace = true and sql user defined function with function_name already exists replace it with create_function_query.
* If persist = true persist function on disk.
*/ */
void registerFunction(const String & function_name, ASTPtr create_function_query, bool replace); void registerFunction(ContextPtr context, const String & function_name, ASTPtr create_function_query, bool replace, bool if_not_exists, bool persist);
/// Unregister function for function_name /** Unregister function for function_name.
void unregisterFunction(const String & function_name); * If if_exists = true then do not throw exception if function is not registered.
* If if_exists = false then throw exception if function is not registered.
*/
void unregisterFunction(ContextPtr context, const String & function_name, bool if_exists);
/// Get function create query for function_name. If no function registered with function_name throws exception. /// Get function create query for function_name. If no function registered with function_name throws exception.
ASTPtr get(const String & function_name) const; ASTPtr get(const String & function_name) const;

View File

@ -22,12 +22,13 @@ def test_sql_user_defined_functions_on_cluster():
assert "Unknown function test_function" in ch2.query_and_get_error("SELECT test_function(1);") assert "Unknown function test_function" in ch2.query_and_get_error("SELECT test_function(1);")
assert "Unknown function test_function" in ch3.query_and_get_error("SELECT test_function(1);") assert "Unknown function test_function" in ch3.query_and_get_error("SELECT test_function(1);")
ch1.query_with_retry("CREATE FUNCTION test_function ON CLUSTER 'cluster' AS x -> x + 1;", retry_count=5) ch1.query_with_retry("CREATE FUNCTION test_function ON CLUSTER 'cluster' AS x -> x + 1;")
assert ch1.query("SELECT test_function(1);") == "2\n" assert ch1.query("SELECT test_function(1);") == "2\n"
assert ch2.query("SELECT test_function(1);") == "2\n" assert ch2.query("SELECT test_function(1);") == "2\n"
assert ch3.query("SELECT test_function(1);") == "2\n" assert ch3.query("SELECT test_function(1);") == "2\n"
ch2.query_with_retry("DROP FUNCTION test_function ON CLUSTER 'cluster'", retry_count=5) ch2.query_with_retry("DROP FUNCTION test_function ON CLUSTER 'cluster'")
assert "Unknown function test_function" in ch1.query_and_get_error("SELECT test_function(1);") assert "Unknown function test_function" in ch1.query_and_get_error("SELECT test_function(1);")
assert "Unknown function test_function" in ch2.query_and_get_error("SELECT test_function(1);") assert "Unknown function test_function" in ch2.query_and_get_error("SELECT test_function(1);")
assert "Unknown function test_function" in ch3.query_and_get_error("SELECT test_function(1);") assert "Unknown function test_function" in ch3.query_and_get_error("SELECT test_function(1);")