Updated UserDefinedSQLFunctionFactory

This commit is contained in:
Maksim Kita 2021-10-27 18:49:18 +03:00
parent d523b28459
commit d6c0cde173
8 changed files with 105 additions and 60 deletions

View File

@ -577,7 +577,7 @@
M(607, BACKUP_ELEMENT_DUPLICATE) \
M(608, CANNOT_RESTORE_TABLE) \
M(609, FUNCTION_ALREADY_EXISTS) \
M(610, CANNOT_DROP_SYSTEM_FUNCTION) \
M(610, CANNOT_DROP_FUNCTION) \
M(611, CANNOT_CREATE_RECURSIVE_FUNCTION) \
M(612, OBJECT_ALREADY_STORED_ON_DISK) \
M(613, OBJECT_WAS_NOT_STORED_ON_DISK) \

View File

@ -24,15 +24,15 @@ namespace ErrorCodes
BlockIO InterpreterCreateFunctionQuery::execute()
{
FunctionNameNormalizer().visit(query_ptr.get());
auto * create_function_query = query_ptr->as<ASTCreateFunctionQuery>();
if (!create_function_query)
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Expected CREATE FUNCTION query");
ASTCreateFunctionQuery & create_function_query = query_ptr->as<ASTCreateFunctionQuery &>();
AccessRightsElements access_rights_elements;
access_rights_elements.emplace_back(AccessType::CREATE_FUNCTION);
if (!create_function_query->cluster.empty())
if (create_function_query.or_replace)
access_rights_elements.emplace_back(AccessType::DROP_FUNCTION);
if (!create_function_query.cluster.empty())
return executeDDLQueryOnCluster(query_ptr, getContext(), access_rights_elements);
auto current_context = getContext();
@ -40,34 +40,16 @@ BlockIO InterpreterCreateFunctionQuery::execute()
auto & user_defined_function_factory = UserDefinedSQLFunctionFactory::instance();
auto & function_name = create_function_query->function_name;
auto & function_name = create_function_query.function_name;
bool if_not_exists = create_function_query->if_not_exists;
bool replace = create_function_query->or_replace;
bool if_not_exists = create_function_query.if_not_exists;
bool replace = create_function_query.or_replace;
create_function_query->if_not_exists = false;
create_function_query->or_replace = false;
create_function_query.if_not_exists = false;
create_function_query.or_replace = false;
if (if_not_exists && user_defined_function_factory.tryGet(function_name) != nullptr)
return {};
validateFunction(create_function_query->function_core, function_name);
user_defined_function_factory.registerFunction(function_name, query_ptr, replace);
if (persist_function)
{
try
{
UserDefinedSQLObjectsLoader::instance().storeObject(current_context, UserDefinedSQLObjectType::Function, function_name, *query_ptr, replace);
}
catch (Exception & exception)
{
user_defined_function_factory.unregisterFunction(function_name);
exception.addMessage(fmt::format("while storing user defined function {} on disk", backQuote(function_name)));
throw;
}
}
validateFunction(create_function_query.function_core, function_name);
user_defined_function_factory.registerFunction(current_context, function_name, query_ptr, replace, if_not_exists, persist_function);
return {};
}

View File

@ -12,35 +12,21 @@
namespace DB
{
namespace ErrorCodes
{
extern const int UNSUPPORTED_METHOD;
}
BlockIO InterpreterDropFunctionQuery::execute()
{
FunctionNameNormalizer().visit(query_ptr.get());
auto * drop_function_query = query_ptr->as<ASTDropFunctionQuery>();
if (!drop_function_query)
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Expected DROP FUNCTION query");
ASTDropFunctionQuery & drop_function_query = query_ptr->as<ASTDropFunctionQuery &>();
AccessRightsElements access_rights_elements;
access_rights_elements.emplace_back(AccessType::DROP_FUNCTION);
if (!drop_function_query->cluster.empty())
if (!drop_function_query.cluster.empty())
return executeDDLQueryOnCluster(query_ptr, getContext(), access_rights_elements);
auto current_context = getContext();
current_context->checkAccess(access_rights_elements);
auto & user_defined_functions_factory = UserDefinedSQLFunctionFactory::instance();
if (drop_function_query->if_exists && !user_defined_functions_factory.has(drop_function_query->function_name))
return {};
UserDefinedSQLFunctionFactory::instance().unregisterFunction(drop_function_query->function_name);
UserDefinedSQLObjectsLoader::instance().removeObject(current_context, UserDefinedSQLObjectType::Function, drop_function_query->function_name);
UserDefinedSQLFunctionFactory::instance().unregisterFunction(current_context, drop_function_query.function_name, drop_function_query.if_exists);
return {};
}

View File

@ -206,6 +206,15 @@ FunctionOverloadResolverPtr UserDefinedExecutableFunctionFactory::tryGet(const S
return nullptr;
}
bool UserDefinedExecutableFunctionFactory::has(const String & function_name, ContextPtr context)
{
const auto & loader = context->getExternalUserDefinedExecutableFunctionsLoader();
auto load_result = loader.getLoadResult(function_name);
bool result = load_result.object != nullptr;
return result;
}
std::vector<String> UserDefinedExecutableFunctionFactory::getRegisteredNames(ContextPtr context)
{
const auto & loader = context->getExternalUserDefinedExecutableFunctionsLoader();

View File

@ -24,6 +24,8 @@ public:
static FunctionOverloadResolverPtr tryGet(const String & function_name, ContextPtr context);
static bool has(const String & function_name, ContextPtr context);
static std::vector<String> getRegisteredNames(ContextPtr context);
};

View File

@ -1,7 +1,13 @@
#include "UserDefinedSQLFunctionFactory.h"
#include <Common/quoteString.h>
#include <Functions/FunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Interpreters/UserDefinedSQLObjectsLoader.h>
#include <Interpreters/UserDefinedExecutableFunctionFactory.h>
#include <Interpreters/Context.h>
namespace DB
{
@ -10,7 +16,7 @@ namespace ErrorCodes
{
extern const int FUNCTION_ALREADY_EXISTS;
extern const int UNKNOWN_FUNCTION;
extern const int CANNOT_DROP_SYSTEM_FUNCTION;
extern const int CANNOT_DROP_FUNCTION;
}
UserDefinedSQLFunctionFactory & UserDefinedSQLFunctionFactory::instance()
@ -19,13 +25,31 @@ UserDefinedSQLFunctionFactory & UserDefinedSQLFunctionFactory::instance()
return result;
}
void UserDefinedSQLFunctionFactory::registerFunction(const String & function_name, ASTPtr create_function_query, bool replace)
void UserDefinedSQLFunctionFactory::registerFunction(ContextPtr context, const String & function_name, ASTPtr create_function_query, bool replace, bool if_not_exists, bool persist)
{
if (FunctionFactory::instance().hasNameOrAlias(function_name))
{
if (if_not_exists)
return;
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The function '{}' already exists", function_name);
}
if (AggregateFunctionFactory::instance().hasNameOrAlias(function_name))
{
if (if_not_exists)
return;
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The aggregate function '{}' already exists", function_name);
}
if (UserDefinedExecutableFunctionFactory::instance().has(function_name, context))
{
if (if_not_exists)
return;
throw Exception(ErrorCodes::CANNOT_DROP_FUNCTION, "User defined executable function '{}'", function_name);
}
std::lock_guard lock(mutex);
@ -33,28 +57,63 @@ void UserDefinedSQLFunctionFactory::registerFunction(const String & function_nam
if (!inserted)
{
if (if_not_exists)
return;
if (replace)
it->second = std::move(create_function_query);
it->second = create_function_query;
else
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS,
"The function name '{}' is not unique",
function_name);
}
if (persist)
{
try
{
UserDefinedSQLObjectsLoader::instance().storeObject(context, UserDefinedSQLObjectType::Function, function_name, *create_function_query, replace);
}
catch (Exception & exception)
{
function_name_to_create_query.erase(it);
exception.addMessage(fmt::format("while storing user defined function {} on disk", backQuote(function_name)));
throw;
}
}
}
void UserDefinedSQLFunctionFactory::unregisterFunction(const String & function_name)
void UserDefinedSQLFunctionFactory::unregisterFunction(ContextPtr context, const String & function_name, bool if_exists)
{
if (FunctionFactory::instance().hasNameOrAlias(function_name) ||
AggregateFunctionFactory::instance().hasNameOrAlias(function_name))
throw Exception(ErrorCodes::CANNOT_DROP_SYSTEM_FUNCTION, "Cannot drop system function '{}'", function_name);
throw Exception(ErrorCodes::CANNOT_DROP_FUNCTION, "Cannot drop system function '{}'", function_name);
if (UserDefinedExecutableFunctionFactory::instance().has(function_name, context))
throw Exception(ErrorCodes::CANNOT_DROP_FUNCTION, "Cannot drop user defined executable function '{}'", function_name);
std::lock_guard lock(mutex);
auto it = function_name_to_create_query.find(function_name);
if (it == function_name_to_create_query.end())
{
if (if_exists)
return;
throw Exception(ErrorCodes::UNKNOWN_FUNCTION,
"The function name '{}' is not registered",
function_name);
}
try
{
UserDefinedSQLObjectsLoader::instance().removeObject(context, UserDefinedSQLObjectType::Function, function_name);
}
catch (Exception & exception)
{
exception.addMessage(fmt::format("while removing user defined function {} from disk", backQuote(function_name)));
throw;
}
function_name_to_create_query.erase(it);
}

View File

@ -6,6 +6,8 @@
#include <Common/NamePrompter.h>
#include <Parsers/ASTCreateFunctionQuery.h>
#include <Interpreters/Context_fwd.h>
namespace DB
{
@ -17,13 +19,17 @@ public:
static UserDefinedSQLFunctionFactory & instance();
/** Register function for function_name in factory for specified create_function_query.
* If replace = true and function with function_name already exists replace it with create_function_query.
* Otherwise throws exception.
* If function exists and if_not_exists = false and replace = false throws exception.
* If replace = true and sql user defined function with function_name already exists replace it with create_function_query.
* If persist = true persist function on disk.
*/
void registerFunction(const String & function_name, ASTPtr create_function_query, bool replace);
void registerFunction(ContextPtr context, const String & function_name, ASTPtr create_function_query, bool replace, bool if_not_exists, bool persist);
/// Unregister function for function_name
void unregisterFunction(const String & function_name);
/** Unregister function for function_name.
* If if_exists = true then do not throw exception if function is not registered.
* If if_exists = false then throw exception if function is not registered.
*/
void unregisterFunction(ContextPtr context, const String & function_name, bool if_exists);
/// Get function create query for function_name. If no function registered with function_name throws exception.
ASTPtr get(const String & function_name) const;

View File

@ -22,12 +22,13 @@ def test_sql_user_defined_functions_on_cluster():
assert "Unknown function test_function" in ch2.query_and_get_error("SELECT test_function(1);")
assert "Unknown function test_function" in ch3.query_and_get_error("SELECT test_function(1);")
ch1.query_with_retry("CREATE FUNCTION test_function ON CLUSTER 'cluster' AS x -> x + 1;", retry_count=5)
ch1.query_with_retry("CREATE FUNCTION test_function ON CLUSTER 'cluster' AS x -> x + 1;")
assert ch1.query("SELECT test_function(1);") == "2\n"
assert ch2.query("SELECT test_function(1);") == "2\n"
assert ch3.query("SELECT test_function(1);") == "2\n"
ch2.query_with_retry("DROP FUNCTION test_function ON CLUSTER 'cluster'", retry_count=5)
ch2.query_with_retry("DROP FUNCTION test_function ON CLUSTER 'cluster'")
assert "Unknown function test_function" in ch1.query_and_get_error("SELECT test_function(1);")
assert "Unknown function test_function" in ch2.query_and_get_error("SELECT test_function(1);")
assert "Unknown function test_function" in ch3.query_and_get_error("SELECT test_function(1);")