#include "UserDefinedSQLFunctionVisitor.h"

// NOTE(review): the include targets below are reconstructed from the names this
// file uses (std containers, AST node classes, the UDF factory) — confirm the
// exact paths against the build.
#include <unordered_map>
#include <unordered_set>
#include <stack>

#include <Parsers/ASTFunction.h>
#include <Parsers/ASTCreateFunctionQuery.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTIdentifier.h>
#include <Interpreters/UserDefinedSQLFunctionFactory.h>


namespace DB
{

namespace ErrorCodes
{
    extern const int UNSUPPORTED_METHOD;
}

/// If `ast` is a function node, tries to expand it as a user defined SQL function
/// and, on success, replaces `ast` with the expanded expression.
/// Nodes that are not functions, or functions unknown to the factory, are left untouched.
void UserDefinedSQLFunctionMatcher::visit(ASTPtr & ast, Data &)
{
    auto * function = ast->as<ASTFunction>();
    if (!function)
        return;

    /// Names of the user defined functions currently being expanded; used to detect recursion.
    std::unordered_set<std::string> udf_in_replace_process;

    auto replace_result = tryToReplaceFunction(*function, udf_in_replace_process);
    if (replace_result)
        ast = replace_result;
}

/// Children are always visited.
bool UserDefinedSQLFunctionMatcher::needChildVisit(const ASTPtr &, const ASTPtr &)
{
    return true;
}

/** Expands a call of a user defined SQL function into a copy of its body in which
  * the parameter identifiers are substituted with clones of the call arguments.
  *
  * @param function                 The function call node to expand.
  * @param udf_in_replace_process   Names of user defined functions whose expansion is in progress
  *                                 higher up the call chain. The name of `function` is added for
  *                                 the duration of the expansion and removed before returning normally.
  * @return The expanded expression (carrying the alias of `function`, if it has one),
  *         or nullptr if the factory has no user defined function with this name.
  * @throws Exception(UNSUPPORTED_METHOD) if `function.name` is already being expanded (recursion),
  *         or if the number of call arguments differs from the number of declared parameters.
  */
ASTPtr UserDefinedSQLFunctionMatcher::tryToReplaceFunction(const ASTFunction & function, std::unordered_set<std::string> & udf_in_replace_process)
{
    if (udf_in_replace_process.find(function.name) != udf_in_replace_process.end())
        throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
            "Recursive function call detected during function call {}",
            function.name);

    auto user_defined_function = UserDefinedSQLFunctionFactory::instance().tryGet(function.name);
    if (!user_defined_function)
        return nullptr;

    /// Arguments of the call being expanded.
    const auto & function_arguments_list = function.children.at(0)->as<ASTExpressionList>();
    auto & function_arguments = function_arguments_list->children;

    /// The stored definition: the first child of function_core holds the parameter list
    /// (under children[0]->children[0]) and the body (children[1]).
    const auto & create_function_query = user_defined_function->as<ASTCreateFunctionQuery>();
    auto & function_core_expression = create_function_query->function_core->children.at(0);

    const auto & identifiers_expression_list = function_core_expression->children.at(0)->children.at(0)->as<ASTExpressionList>();
    const auto & identifiers_raw = identifiers_expression_list->children;

    if (function_arguments.size() != identifiers_raw.size())
        throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
            "Function {} expects {} arguments actual arguments {}",
            create_function_query->getFunctionName(),
            identifiers_raw.size(),
            function_arguments.size());

    /// Map every declared parameter name to the argument passed at the same position.
    std::unordered_map<std::string, ASTPtr> identifier_name_to_function_argument;

    for (size_t parameter_index = 0; parameter_index < identifiers_raw.size(); ++parameter_index)
    {
        const auto & identifier = identifiers_raw[parameter_index]->as<ASTIdentifier>();
        const auto & function_argument = function_arguments[parameter_index];
        const auto & identifier_name = identifier->name();

        identifier_name_to_function_argument.emplace(identifier_name, function_argument);
    }

    /// Mark this function as "being expanded". The iterator returned by emplace is deliberately
    /// not kept: the recursive calls below insert into the same set, and an insertion that
    /// triggers a rehash invalidates all iterators. The entry is erased by key instead.
    udf_in_replace_process.emplace(function.name);

    /// Work on a copy of the body. It is wrapped into an expression list so that the body root
    /// itself is processed as a child (and thus can be replaced) by the loop below.
    auto function_body_to_update = function_core_expression->children.at(1)->clone();

    auto expression_list = std::make_shared<ASTExpressionList>();
    expression_list->children.emplace_back(std::move(function_body_to_update));

    std::stack<ASTPtr> ast_nodes_to_update;
    ast_nodes_to_update.push(expression_list);

    while (!ast_nodes_to_update.empty())
    {
        auto ast_node_to_update = ast_nodes_to_update.top();
        ast_nodes_to_update.pop();

        for (auto & child : ast_node_to_update->children)
        {
            /// Nested calls of user defined functions are expanded recursively.
            if (auto * inner_function = child->as<ASTFunction>())
            {
                auto replace_result = tryToReplaceFunction(*inner_function, udf_in_replace_process);
                if (replace_result)
                    child = replace_result;
            }

            auto identifier_name_opt = tryGetIdentifierName(child);
            if (identifier_name_opt)
            {
                auto function_argument_it = identifier_name_to_function_argument.find(*identifier_name_opt);

                /// An identifier that is not a parameter of this function stays as it is.
                if (function_argument_it == identifier_name_to_function_argument.end())
                    continue;

                /// Substitute the parameter with a clone of the argument, keeping the alias
                /// the identifier had inside the body.
                auto child_alias = child->tryGetAlias();
                child = function_argument_it->second->clone();

                if (!child_alias.empty())
                    child->setAlias(child_alias);

                continue;
            }

            ast_nodes_to_update.push(child);
        }
    }

    udf_in_replace_process.erase(function.name);

    function_body_to_update = expression_list->children[0];

    /// The expanded expression takes over the alias of the original call.
    auto function_alias = function.tryGetAlias();

    if (!function_alias.empty())
        function_body_to_update->setAlias(function_alias);

    return function_body_to_update;
}

}