#include #include #include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int UNKNOWN_IDENTIFIER; extern const int CANNOT_CREATE_RECURSIVE_FUNCTION; extern const int UNSUPPORTED_METHOD; } BlockIO InterpreterCreateFunctionQuery::execute() { auto current_context = getContext(); current_context->checkAccess(AccessType::CREATE_FUNCTION); FunctionNameNormalizer().visit(query_ptr.get()); auto * create_function_query = query_ptr->as(); if (!create_function_query) throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Expected CREATE FUNCTION query"); auto & function_name = create_function_query->function_name; validateFunction(create_function_query->function_core, function_name); UserDefinedSQLFunctionFactory::instance().registerFunction(function_name, query_ptr); if (!persist_function) { try { UserDefinedSQLObjectsLoader::instance().storeObject(current_context, UserDefinedSQLObjectType::Function, function_name, *query_ptr); } catch (Exception & exception) { UserDefinedSQLFunctionFactory::instance().unregisterFunction(function_name); exception.addMessage(fmt::format("while storing user defined function {} on disk", backQuote(function_name))); throw; } } return {}; } void InterpreterCreateFunctionQuery::validateFunction(ASTPtr function, const String & name) { const auto * args_tuple = function->as()->arguments->children.at(0)->as(); std::unordered_set arguments; for (const auto & argument : args_tuple->arguments->children) { const auto & argument_name = argument->as()->name(); auto [_, inserted] = arguments.insert(argument_name); if (!inserted) throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Identifier {} already used as function parameter", argument_name); } ASTPtr function_body = function->as()->children.at(0)->children.at(1); std::unordered_set identifiers_in_body = getIdentifiers(function_body); for (const auto & identifier : identifiers_in_body) { if (!arguments.contains(identifier)) throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER, "Identifier {} does not exist in arguments", backQuote(identifier)); } validateFunctionRecursiveness(function_body, name); } std::unordered_set InterpreterCreateFunctionQuery::getIdentifiers(ASTPtr node) { std::unordered_set identifiers; std::stack ast_nodes_to_process; ast_nodes_to_process.push(node); while (!ast_nodes_to_process.empty()) { auto ast_node_to_process = ast_nodes_to_process.top(); ast_nodes_to_process.pop(); for (const auto & child : ast_node_to_process->children) { auto identifier_name_opt = tryGetIdentifierName(child); if (identifier_name_opt) identifiers.insert(identifier_name_opt.value()); ast_nodes_to_process.push(child); } } return identifiers; } void InterpreterCreateFunctionQuery::validateFunctionRecursiveness(ASTPtr node, const String & function_to_create) { for (const auto & child : node->children) { auto function_name_opt = tryGetFunctionName(child); if (function_name_opt && function_name_opt.value() == function_to_create) throw Exception(ErrorCodes::CANNOT_CREATE_RECURSIVE_FUNCTION, "You cannot create recursive function"); validateFunctionRecursiveness(child, function_to_create); } } }