// ClickHouse/src/Interpreters/InterpreterCreateFunctionQuery.cpp

#include <Interpreters/InterpreterCreateFunctionQuery.h>
#include <Access/ContextAccess.h>
#include <Parsers/ASTCreateFunctionQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Interpreters/Context.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/FunctionNameNormalizer.h>
#include <Interpreters/UserDefinedSQLObjectsLoader.h>
#include <Interpreters/UserDefinedSQLFunctionFactory.h>
#include <Interpreters/executeDDLQueryOnCluster.h>
namespace DB
{

namespace ErrorCodes
{
    extern const int CANNOT_CREATE_RECURSIVE_FUNCTION;
    extern const int UNSUPPORTED_METHOD;
}

/// Executes CREATE [OR REPLACE] FUNCTION [IF NOT EXISTS] name [ON CLUSTER cluster] AS <lambda>.
///
/// With ON CLUSTER the query is forwarded to executeDDLQueryOnCluster together with the
/// required access rights. Otherwise the rights are checked locally, the lambda is validated,
/// registered in UserDefinedSQLFunctionFactory and, when `persist_function` is set, stored via
/// UserDefinedSQLObjectsLoader. If storing fails, the in-memory registration is rolled back
/// (restoring the replaced definition, if there was one) and the exception is rethrown.
///
/// Throws UNSUPPORTED_METHOD if query_ptr is not a CREATE FUNCTION query or the lambda is
/// malformed, and CANNOT_CREATE_RECURSIVE_FUNCTION if the body calls the function being created.
BlockIO InterpreterCreateFunctionQuery::execute()
{
    FunctionNameNormalizer().visit(query_ptr.get());

    auto * create_function_query = query_ptr->as<ASTCreateFunctionQuery>();
    if (!create_function_query)
        throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Expected CREATE FUNCTION query");

    AccessRightsElements access_rights_elements;
    access_rights_elements.emplace_back(AccessType::CREATE_FUNCTION);

    /// OR REPLACE may overwrite (i.e. effectively drop) an existing function,
    /// so it must also require the right to drop functions.
    if (create_function_query->or_replace)
        access_rights_elements.emplace_back(AccessType::DROP_FUNCTION);

    if (!create_function_query->cluster.empty())
        return executeDDLQueryOnCluster(query_ptr, getContext(), access_rights_elements);

    auto current_context = getContext();
    current_context->checkAccess(access_rights_elements);

    auto & user_defined_function_factory = UserDefinedSQLFunctionFactory::instance();

    auto & function_name = create_function_query->function_name;

    bool if_not_exists = create_function_query->if_not_exists;
    bool replace = create_function_query->or_replace;

    /// The modifiers are consumed here and cleared on the AST that is registered and stored;
    /// presumably so the saved definition is a plain CREATE FUNCTION.
    create_function_query->if_not_exists = false;
    create_function_query->or_replace = false;

    if (if_not_exists && user_defined_function_factory.tryGet(function_name) != nullptr)
        return {};

    validateFunction(create_function_query->function_core, function_name);

    /// Remember the definition being replaced (if any) so that a failed store restores it
    /// instead of leaving the name unregistered while the old definition is still stored.
    /// NOTE(review): this lookup and the registration below are separate factory calls.
    ASTPtr previous_create_query;
    if (replace)
        previous_create_query = user_defined_function_factory.tryGet(function_name);

    user_defined_function_factory.registerFunction(function_name, query_ptr, replace);

    if (persist_function)
    {
        /// Undo the in-memory registration made above.
        auto rollback_registration = [&]
        {
            if (previous_create_query)
                user_defined_function_factory.registerFunction(function_name, previous_create_query, /* replace = */ true);
            else
                user_defined_function_factory.unregisterFunction(function_name);
        };

        try
        {
            UserDefinedSQLObjectsLoader::instance().storeObject(current_context, UserDefinedSQLObjectType::Function, function_name, *query_ptr, replace);
        }
        catch (Exception & exception)
        {
            rollback_registration();
            exception.addMessage(fmt::format("while storing user defined function {} on disk", backQuote(function_name)));
            throw;
        }
        catch (...)
        {
            /// Exceptions not derived from DB::Exception must not leave a registered
            /// but unstored function behind either.
            rollback_registration();
            throw;
        }
    }

    return {};
}

/// Validates the function core of CREATE FUNCTION.
///
/// The expected shape is a function (the lambda) with exactly two arguments: a function whose
/// arguments list the parameters (presumably `tuple(arg1, ..., argN)` as built by the parser)
/// and the body. Every parameter must be a distinct identifier, and the body must not call
/// the function `name` that is being created.
///
/// Throws UNSUPPORTED_METHOD on a malformed lambda or a duplicate parameter name, and
/// CANNOT_CREATE_RECURSIVE_FUNCTION (via validateFunctionRecursiveness) on self-reference.
void InterpreterCreateFunctionQuery::validateFunction(ASTPtr function, const String & name)
{
    const auto * lambda_function = function ? function->as<ASTFunction>() : nullptr;
    if (!lambda_function || !lambda_function->arguments)
        throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Expected lambda function as user defined function definition");

    const auto & lambda_arguments = lambda_function->arguments->children;
    if (lambda_arguments.size() != 2)
        throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Lambda must have arguments and body");

    const auto * args_tuple = lambda_arguments[0] ? lambda_arguments[0]->as<ASTFunction>() : nullptr;
    if (!args_tuple || !args_tuple->arguments)
        throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Lambda must have valid arguments");

    std::unordered_set<String> arguments;
    for (const auto & argument : args_tuple->arguments->children)
    {
        /// A non-identifier parameter would otherwise be a null dereference.
        const auto * argument_identifier = argument ? argument->as<ASTIdentifier>() : nullptr;
        if (!argument_identifier)
            throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Lambda argument must be identifier");

        const auto & argument_name = argument_identifier->name();
        auto [_, inserted] = arguments.insert(argument_name);
        if (!inserted)
            throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Identifier {} already used as function parameter", argument_name);
    }

    const ASTPtr & function_body = lambda_arguments[1];
    if (!function_body)
        throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Lambda must have valid function body");

    validateFunctionRecursiveness(function_body, name);
}

/// Throws CANNOT_CREATE_RECURSIVE_FUNCTION if `node` or any node below it is a call
/// of `function_to_create`.
///
/// The node itself is checked too, so a body that is exactly a self-call (e.g. `x -> f(x)`)
/// is rejected. NOTE(review): only direct self-reference is detected here; recursion through
/// other user-defined functions is not examined by this function.
void InterpreterCreateFunctionQuery::validateFunctionRecursiveness(ASTPtr node, const String & function_to_create)
{
    if (!node)
        return;

    auto function_name_opt = tryGetFunctionName(node);
    if (function_name_opt && function_name_opt.value() == function_to_create)
        throw Exception(ErrorCodes::CANNOT_CREATE_RECURSIVE_FUNCTION, "You cannot create recursive function");

    for (const auto & child : node->children)
        validateFunctionRecursiveness(child, function_to_create);
}

}