mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
Updated user defined functions implementation
This commit is contained in:
parent
6b2c249adc
commit
01682a86b3
@ -111,8 +111,8 @@ endif()
|
||||
list (APPEND clickhouse_common_io_sources ${CONFIG_BUILD})
|
||||
list (APPEND clickhouse_common_io_headers ${CONFIG_VERSION} ${CONFIG_COMMON})
|
||||
|
||||
list (APPEND dbms_sources Functions/IFunction.cpp Functions/FunctionFactory.cpp Functions/FunctionHelpers.cpp Functions/extractTimeZoneFromFunctionArguments.cpp Functions/replicate.cpp Functions/FunctionsLogical.cpp Functions/UserDefinedFunction.cpp)
|
||||
list (APPEND dbms_headers Functions/IFunction.h Functions/FunctionFactory.h Functions/FunctionHelpers.h Functions/extractTimeZoneFromFunctionArguments.h Functions/replicate.h Functions/FunctionsLogical.h Functions/UserDefinedFunction.h)
|
||||
list (APPEND dbms_sources Functions/IFunction.cpp Functions/FunctionFactory.cpp Functions/FunctionHelpers.cpp Functions/extractTimeZoneFromFunctionArguments.cpp Functions/replicate.cpp Functions/FunctionsLogical.cpp)
|
||||
list (APPEND dbms_headers Functions/IFunction.h Functions/FunctionFactory.h Functions/FunctionHelpers.h Functions/extractTimeZoneFromFunctionArguments.h Functions/replicate.h Functions/FunctionsLogical.h)
|
||||
|
||||
list (APPEND dbms_sources
|
||||
AggregateFunctions/IAggregateFunction.cpp
|
||||
|
@ -567,11 +567,11 @@
|
||||
M(596, INTERSECT_OR_EXCEPT_RESULT_STRUCTURES_MISMATCH) \
|
||||
M(597, NO_SUCH_ERROR_CODE) \
|
||||
\
|
||||
M(591, FUNCTION_ALREADY_EXISTS) \
|
||||
M(592, CANNOT_DROP_SYSTEM_FUNCTION) \
|
||||
M(593, CANNOT_CREATE_RECURSIVE_FUNCTION) \
|
||||
M(594, OBJECT_ALREADY_STORED_ON_DISK) \
|
||||
M(595, OBJECT_WAS_NOT_STORED_ON_DISK) \
|
||||
M(598, FUNCTION_ALREADY_EXISTS) \
|
||||
M(599, CANNOT_DROP_SYSTEM_FUNCTION) \
|
||||
M(600, CANNOT_CREATE_RECURSIVE_FUNCTION) \
|
||||
M(601, OBJECT_ALREADY_STORED_ON_DISK) \
|
||||
M(602, OBJECT_WAS_NOT_STORED_ON_DISK) \
|
||||
\
|
||||
M(998, POSTGRESQL_CONNECTION_FAILURE) \
|
||||
M(999, KEEPER_EXCEPTION) \
|
||||
|
@ -1,5 +1,4 @@
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/UserDefinedFunction.h>
|
||||
|
||||
#include <Interpreters/Context.h>
|
||||
|
||||
@ -36,7 +35,6 @@ void FunctionFactory::registerFunction(
|
||||
Value creator,
|
||||
CaseSensitiveness case_sensitiveness)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(mutex);
|
||||
if (!functions.emplace(name, creator).second)
|
||||
throw Exception("FunctionFactory: the function name '" + name + "' is not unique",
|
||||
ErrorCodes::LOGICAL_ERROR);
|
||||
@ -79,7 +77,6 @@ FunctionOverloadResolverPtr FunctionFactory::getImpl(
|
||||
|
||||
std::vector<std::string> FunctionFactory::getAllNames() const
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(mutex);
|
||||
std::vector<std::string> res;
|
||||
res.reserve(functions.size());
|
||||
for (const auto & func : functions)
|
||||
@ -98,7 +95,6 @@ FunctionOverloadResolverPtr FunctionFactory::tryGetImpl(
|
||||
const std::string & name_param,
|
||||
ContextPtr context) const
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(mutex);
|
||||
String name = getAliasToOrName(name_param);
|
||||
FunctionOverloadResolverPtr res;
|
||||
|
||||
@ -140,48 +136,4 @@ FunctionFactory & FunctionFactory::instance()
|
||||
return ret;
|
||||
}
|
||||
|
||||
void FunctionFactory::registerUserDefinedFunction(const ASTCreateFunctionQuery & create_function_query)
|
||||
{
|
||||
if (hasNameOrAlias(create_function_query.function_name))
|
||||
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The function {} already exists", create_function_query.function_name);
|
||||
|
||||
if (AggregateFunctionFactory::instance().isAggregateFunctionName(create_function_query.function_name))
|
||||
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The aggregate function {} already exists", create_function_query.function_name);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(mutex);
|
||||
user_defined_functions.insert(create_function_query.function_name);
|
||||
}
|
||||
registerFunction(create_function_query.function_name, [create_function_query](ContextPtr context)
|
||||
{
|
||||
auto function = UserDefinedFunction::create(context);
|
||||
function->setName(create_function_query.function_name);
|
||||
function->setFunctionCore(create_function_query.function_core);
|
||||
|
||||
FunctionOverloadResolverPtr res = std::make_unique<FunctionToOverloadResolverAdaptor>(function);
|
||||
return res;
|
||||
}, CaseSensitiveness::CaseSensitive);
|
||||
}
|
||||
|
||||
void FunctionFactory::unregisterUserDefinedFunction(const String & name)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(mutex);
|
||||
if (functions.contains(name))
|
||||
{
|
||||
if (user_defined_functions.contains(name))
|
||||
{
|
||||
functions.erase(name);
|
||||
user_defined_functions.erase(name);
|
||||
return;
|
||||
} else
|
||||
throw Exception("System functions cannot be dropped", ErrorCodes::CANNOT_DROP_SYSTEM_FUNCTION);
|
||||
}
|
||||
|
||||
auto hints = this->getHints(name);
|
||||
if (!hints.empty())
|
||||
throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unknown function {}. Maybe you meant: {}", name, toString(hints));
|
||||
else
|
||||
throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unknown function {}", name);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -70,9 +70,7 @@ private:
|
||||
using Functions = std::unordered_map<std::string, Value>;
|
||||
|
||||
Functions functions;
|
||||
std::unordered_set<String> user_defined_functions;
|
||||
Functions case_insensitive_functions;
|
||||
mutable std::mutex mutex;
|
||||
|
||||
template <typename Function>
|
||||
static FunctionOverloadResolverPtr adaptFunctionToOverloadResolver(ContextPtr context)
|
||||
|
@ -1,93 +0,0 @@
|
||||
#include <DataTypes/DataTypeFactory.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <Functions/UserDefinedFunction.h>
|
||||
#include <Interpreters/TreeRewriter.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <Interpreters/ExpressionAnalyzer.h>
|
||||
#include <Parsers/ASTIdentifier.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int TYPE_MISMATCH;
|
||||
}
|
||||
|
||||
UserDefinedFunction::UserDefinedFunction(ContextPtr context_)
|
||||
: function_core(nullptr)
|
||||
, context(context_)
|
||||
{}
|
||||
|
||||
UserDefinedFunctionPtr UserDefinedFunction::create(ContextPtr context)
|
||||
{
|
||||
return std::make_shared<UserDefinedFunction>(context);
|
||||
}
|
||||
|
||||
String UserDefinedFunction::getName() const
|
||||
{
|
||||
return name;
|
||||
}
|
||||
|
||||
ColumnPtr UserDefinedFunction::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const
|
||||
{
|
||||
Block block = executeCore(arguments);
|
||||
|
||||
String result_name = function_core->as<ASTFunction>()->arguments->children.at(1)->getColumnName();
|
||||
|
||||
// result of function executing was inserted in the end
|
||||
return block.getColumns().back();
|
||||
}
|
||||
|
||||
size_t UserDefinedFunction::getNumberOfArguments() const
|
||||
{
|
||||
return function_core->as<ASTFunction>()->arguments->children[0]->size() - 2;
|
||||
}
|
||||
|
||||
void UserDefinedFunction::setName(const String & name_)
|
||||
{
|
||||
name = name_;
|
||||
}
|
||||
|
||||
void UserDefinedFunction::setFunctionCore(ASTPtr function_core_)
|
||||
{
|
||||
function_core = function_core_;
|
||||
}
|
||||
|
||||
DataTypePtr UserDefinedFunction::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
Block block = executeCore(arguments);
|
||||
return block.getDataTypes().back();
|
||||
}
|
||||
|
||||
Block UserDefinedFunction::executeCore(const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
const auto * lambda_args_tuple = function_core->as<ASTFunction>()->arguments->children.at(0)->as<ASTFunction>();
|
||||
const ASTs & lambda_arg_asts = lambda_args_tuple->arguments->children;
|
||||
|
||||
NamesAndTypesList lambda_arguments;
|
||||
Block block;
|
||||
|
||||
for (size_t j = 0; j < lambda_arg_asts.size(); ++j)
|
||||
{
|
||||
auto opt_arg_name = tryGetIdentifierName(lambda_arg_asts[j]);
|
||||
if (!opt_arg_name)
|
||||
throw Exception("lambda argument declarations must be identifiers", ErrorCodes::TYPE_MISMATCH);
|
||||
|
||||
lambda_arguments.emplace_back(*opt_arg_name, arguments[j].type);
|
||||
auto column_ptr = arguments[j].column;
|
||||
if (!column_ptr)
|
||||
column_ptr = arguments[j].type->createColumnConstWithDefaultValue(1);
|
||||
block.insert({column_ptr, arguments[j].type, *opt_arg_name});
|
||||
}
|
||||
|
||||
ASTPtr lambda_body = function_core->as<ASTFunction>()->children.at(0)->children.at(1);
|
||||
auto syntax_result = TreeRewriter(context).analyze(lambda_body, lambda_arguments);
|
||||
ExpressionAnalyzer analyzer(lambda_body, syntax_result, context);
|
||||
ExpressionActionsPtr actions = analyzer.getActions(false);
|
||||
|
||||
actions->execute(block);
|
||||
return block;
|
||||
}
|
||||
|
||||
}
|
@ -1,38 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Parsers/ASTCreateFunctionQuery.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class UserDefinedFunction;
|
||||
using UserDefinedFunctionPtr = std::shared_ptr<UserDefinedFunction>;
|
||||
|
||||
class UserDefinedFunction : public IFunction
|
||||
{
|
||||
public:
|
||||
explicit UserDefinedFunction(ContextPtr context_);
|
||||
static UserDefinedFunctionPtr create(ContextPtr context);
|
||||
|
||||
String getName() const override;
|
||||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
|
||||
size_t getNumberOfArguments() const override;
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
|
||||
|
||||
void setName(const String & name_);
|
||||
void setFunctionCore(ASTPtr function_core_);
|
||||
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
|
||||
|
||||
private:
|
||||
Block executeCore(const ColumnsWithTypeAndName & arguments) const;
|
||||
|
||||
String name;
|
||||
ASTPtr function_core;
|
||||
ContextPtr context;
|
||||
};
|
||||
|
||||
}
|
@ -16,7 +16,7 @@ class InDepthNodeVisitor
|
||||
public:
|
||||
using Data = typename Matcher::Data;
|
||||
|
||||
InDepthNodeVisitor(Data & data_, WriteBuffer * ostr_ = nullptr)
|
||||
explicit InDepthNodeVisitor(Data & data_, WriteBuffer * ostr_ = nullptr)
|
||||
: data(data_),
|
||||
visit_depth(0),
|
||||
ostr(ostr_)
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <Interpreters/InterpreterCreateFunctionQuery.h>
|
||||
#include <Interpreters/FunctionNameNormalizer.h>
|
||||
#include <Interpreters/UserDefinedObjectsLoader.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Interpreters/UserDefinedFunctionFactory.h>
|
||||
#include <Parsers/ASTIdentifier.h>
|
||||
|
||||
namespace DB
|
||||
@ -15,7 +15,7 @@ namespace ErrorCodes
|
||||
{
|
||||
extern const int UNKNOWN_IDENTIFIER;
|
||||
extern const int CANNOT_CREATE_RECURSIVE_FUNCTION;
|
||||
// extern const int UNSUPPORTED_OPERATION;
|
||||
extern const int UNSUPPORTED_METHOD;
|
||||
}
|
||||
|
||||
BlockIO InterpreterCreateFunctionQuery::execute()
|
||||
@ -24,25 +24,24 @@ BlockIO InterpreterCreateFunctionQuery::execute()
|
||||
FunctionNameNormalizer().visit(query_ptr.get());
|
||||
auto * create_function_query = query_ptr->as<ASTCreateFunctionQuery>();
|
||||
|
||||
// if (!create_function_query)
|
||||
// throw Exception(ErrorCodes::UNSUPPORTED_OPERATION, "Expected CREATE FUNCTION query");
|
||||
if (!create_function_query)
|
||||
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Expected CREATE FUNCTION query");
|
||||
|
||||
auto & function_name = create_function_query->function_name;
|
||||
validateFunction(create_function_query->function_core, function_name);
|
||||
|
||||
if (is_internal)
|
||||
{
|
||||
FunctionFactory::instance().registerUserDefinedFunction(*create_function_query);
|
||||
}
|
||||
else
|
||||
UserDefinedFunctionFactory::instance().registerFunction(function_name, query_ptr);
|
||||
|
||||
if (!is_internal)
|
||||
{
|
||||
|
||||
try
|
||||
{
|
||||
UserDefinedObjectsLoader::instance().storeObject(getContext(), UserDefinedObjectType::Function, function_name, *query_ptr);
|
||||
FunctionFactory::instance().registerUserDefinedFunction(*create_function_query);
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
UserDefinedFunctionFactory::instance().unregisterFunction(function_name);
|
||||
e.addMessage(fmt::format("while storing user defined function {} on disk", backQuote(function_name)));
|
||||
throw;
|
||||
}
|
||||
@ -56,11 +55,15 @@ void InterpreterCreateFunctionQuery::validateFunction(ASTPtr function, const Str
|
||||
const auto * args_tuple = function->as<ASTFunction>()->arguments->children.at(0)->as<ASTFunction>();
|
||||
std::unordered_set<String> arguments;
|
||||
for (const auto & argument : args_tuple->arguments->children)
|
||||
arguments.insert(argument->as<ASTIdentifier>()->name());
|
||||
{
|
||||
const auto & argument_name = argument->as<ASTIdentifier>()->name();
|
||||
auto [_, inserted] = arguments.insert(argument_name);
|
||||
if (!inserted)
|
||||
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "Identifier {} already used as function parameter", argument_name);
|
||||
}
|
||||
|
||||
std::set<String> identifiers_in_body;
|
||||
ASTPtr function_body = function->as<ASTFunction>()->children.at(0)->children.at(1);
|
||||
getIdentifiers(function_body, identifiers_in_body);
|
||||
std::unordered_set<String> identifiers_in_body = getIdentifiers(function_body);
|
||||
|
||||
for (const auto & identifier : identifiers_in_body)
|
||||
{
|
||||
@ -71,16 +74,29 @@ void InterpreterCreateFunctionQuery::validateFunction(ASTPtr function, const Str
|
||||
validateFunctionRecursiveness(function_body, name);
|
||||
}
|
||||
|
||||
void InterpreterCreateFunctionQuery::getIdentifiers(ASTPtr node, std::set<String> & identifiers)
|
||||
std::unordered_set<String> InterpreterCreateFunctionQuery::getIdentifiers(ASTPtr node)
|
||||
{
|
||||
for (const auto & child : node->children)
|
||||
{
|
||||
auto identifier_name_opt = tryGetIdentifierName(child);
|
||||
if (identifier_name_opt)
|
||||
identifiers.insert(identifier_name_opt.value());
|
||||
std::unordered_set<String> identifiers;
|
||||
|
||||
getIdentifiers(child, identifiers);
|
||||
std::stack<ASTPtr> ast_nodes_to_process;
|
||||
ast_nodes_to_process.push(node);
|
||||
|
||||
while (!ast_nodes_to_process.empty())
|
||||
{
|
||||
auto ast_node_to_process = ast_nodes_to_process.top();
|
||||
ast_nodes_to_process.pop();
|
||||
|
||||
for (const auto & child : ast_node_to_process->children)
|
||||
{
|
||||
auto identifier_name_opt = tryGetIdentifierName(child);
|
||||
if (identifier_name_opt)
|
||||
identifiers.insert(identifier_name_opt.value());
|
||||
|
||||
ast_nodes_to_process.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
return identifiers;
|
||||
}
|
||||
|
||||
void InterpreterCreateFunctionQuery::validateFunctionRecursiveness(ASTPtr node, const String & function_to_create)
|
||||
|
@ -23,7 +23,7 @@ public:
|
||||
|
||||
private:
|
||||
static void validateFunction(ASTPtr function, const String & name);
|
||||
static void getIdentifiers(ASTPtr node, std::set<String> & identifiers);
|
||||
static std::unordered_set<String> getIdentifiers(ASTPtr node);
|
||||
static void validateFunctionRecursiveness(ASTPtr node, const String & function_to_create);
|
||||
|
||||
ASTPtr query_ptr;
|
||||
|
@ -3,7 +3,7 @@
|
||||
#include <Interpreters/FunctionNameNormalizer.h>
|
||||
#include <Interpreters/InterpreterDropFunctionQuery.h>
|
||||
#include <Interpreters/UserDefinedObjectsLoader.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Interpreters/UserDefinedFunctionFactory.h>
|
||||
#include <Parsers/ASTDropFunctionQuery.h>
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ BlockIO InterpreterDropFunctionQuery::execute()
|
||||
getContext()->checkAccess(AccessType::DROP_FUNCTION);
|
||||
FunctionNameNormalizer().visit(query_ptr.get());
|
||||
auto & drop_function_query = query_ptr->as<ASTDropFunctionQuery &>();
|
||||
FunctionFactory::instance().unregisterUserDefinedFunction(drop_function_query.function_name);
|
||||
UserDefinedFunctionFactory::instance().unregisterFunction(drop_function_query.function_name);
|
||||
UserDefinedObjectsLoader::instance().removeObject(getContext(), UserDefinedObjectType::Function, drop_function_query.function_name);
|
||||
|
||||
return {};
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <Interpreters/CollectJoinOnKeysVisitor.h>
|
||||
#include <Interpreters/RequiredSourceColumnsVisitor.h>
|
||||
#include <Interpreters/GetAggregatesVisitor.h>
|
||||
#include <Interpreters/UserDefinedFunctionsVisitor.h>
|
||||
#include <Interpreters/TableJoin.h>
|
||||
#include <Interpreters/ExpressionActions.h> /// getSmallestColumn()
|
||||
#include <Interpreters/getTableExpressions.h>
|
||||
@ -1045,6 +1046,9 @@ TreeRewriterResultPtr TreeRewriter::analyze(
|
||||
void TreeRewriter::normalize(
|
||||
ASTPtr & query, Aliases & aliases, const NameSet & source_columns_set, bool ignore_alias, const Settings & settings, bool allow_self_aliases)
|
||||
{
|
||||
UserDefinedFunctionsVisitor::Data data_user_defined_functions_visitor;
|
||||
UserDefinedFunctionsVisitor(data_user_defined_functions_visitor).visit(query);
|
||||
|
||||
CustomizeCountDistinctVisitor::Data data_count_distinct{settings.count_distinct_implementation};
|
||||
CustomizeCountDistinctVisitor(data_count_distinct).visit(query);
|
||||
|
||||
|
82
src/Interpreters/UserDefinedFunctionFactory.cpp
Normal file
82
src/Interpreters/UserDefinedFunctionFactory.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
#include "UserDefinedFunctionFactory.h"
|
||||
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int FUNCTION_ALREADY_EXISTS;
|
||||
extern const int UNKNOWN_FUNCTION;
|
||||
extern const int CANNOT_DROP_SYSTEM_FUNCTION;
|
||||
}
|
||||
|
||||
/// Returns the single factory object, a function-local static created on first use.
UserDefinedFunctionFactory & UserDefinedFunctionFactory::instance()
{
    static UserDefinedFunctionFactory factory;
    return factory;
}
|
||||
|
||||
void UserDefinedFunctionFactory::registerFunction(const String & function_name, ASTPtr create_function_query)
|
||||
{
|
||||
if (FunctionFactory::instance().hasNameOrAlias(function_name))
|
||||
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The function '{}' already exists", function_name);
|
||||
|
||||
if (AggregateFunctionFactory::instance().hasNameOrAlias(function_name))
|
||||
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The aggregate function '{}' already exists", function_name);
|
||||
|
||||
auto [_, inserted] = function_name_to_create_query.emplace(function_name, std::move(create_function_query));
|
||||
if (!inserted)
|
||||
throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS,
|
||||
"The function name '{}' is not unique",
|
||||
function_name);
|
||||
}
|
||||
|
||||
void UserDefinedFunctionFactory::unregisterFunction(const String & function_name)
|
||||
{
|
||||
if (FunctionFactory::instance().hasNameOrAlias(function_name) ||
|
||||
AggregateFunctionFactory::instance().hasNameOrAlias(function_name))
|
||||
throw Exception(ErrorCodes::CANNOT_DROP_SYSTEM_FUNCTION, "Cannot drop system function '{}'", function_name);
|
||||
|
||||
|
||||
auto it = function_name_to_create_query.find(function_name);
|
||||
if (it == function_name_to_create_query.end())
|
||||
throw Exception(ErrorCodes::UNKNOWN_FUNCTION,
|
||||
"The function name '{}' is not registered",
|
||||
function_name);
|
||||
}
|
||||
|
||||
/// Returns the stored CREATE FUNCTION query for `function_name`.
/// Throws UNKNOWN_FUNCTION when nothing is registered under that name.
ASTPtr UserDefinedFunctionFactory::get(const String & function_name) const
{
    const auto found = function_name_to_create_query.find(function_name);
    const bool is_registered = found != function_name_to_create_query.end();

    if (is_registered)
        return found->second;

    throw Exception(ErrorCodes::UNKNOWN_FUNCTION,
        "The function name '{}' is not registered",
        function_name);
}
|
||||
|
||||
/// Non-throwing lookup: returns the stored CREATE FUNCTION query, or nullptr when the name is not registered.
ASTPtr UserDefinedFunctionFactory::tryGet(const std::string & function_name) const
{
    const auto found = function_name_to_create_query.find(function_name);

    if (found != function_name_to_create_query.end())
        return found->second;

    return nullptr;
}
|
||||
|
||||
std::vector<std::string> UserDefinedFunctionFactory::getAllRegisteredNames() const
|
||||
{
|
||||
std::vector<std::string> registered_names;
|
||||
registered_names.reserve(function_name_to_create_query.size());
|
||||
|
||||
for (const auto & [name, _] : function_name_to_create_query)
|
||||
registered_names.emplace_back(name);
|
||||
|
||||
return registered_names;
|
||||
}
|
||||
|
||||
}
|
32
src/Interpreters/UserDefinedFunctionFactory.h
Normal file
32
src/Interpreters/UserDefinedFunctionFactory.h
Normal file
@ -0,0 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <Common/NamePrompter.h>
|
||||
|
||||
#include <Parsers/ASTCreateFunctionQuery.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Registry mapping a user defined function name to the AST of its CREATE FUNCTION query.
/// Accessed through the process-wide instance().
/// NOTE(review): no synchronization is visible in this class — confirm that callers serialize access.
class UserDefinedFunctionFactory : public IHints<1, UserDefinedFunctionFactory>
{
public:
    static UserDefinedFunctionFactory & instance();

    /// Lookup. get() throws UNKNOWN_FUNCTION for a name that is not registered, tryGet() returns nullptr.
    ASTPtr get(const String & function_name) const;
    ASTPtr tryGet(const String & function_name) const;

    /// Names of all registered user defined functions (IHints interface).
    std::vector<String> getAllRegisteredNames() const override;

    /// Registration. Throws FUNCTION_ALREADY_EXISTS when the name is already taken by an ordinary,
    /// aggregate or user defined function.
    void registerFunction(const String & function_name, ASTPtr create_function_query);

    /// Drop entry point. Throws CANNOT_DROP_SYSTEM_FUNCTION for names known to FunctionFactory or
    /// AggregateFunctionFactory, and UNKNOWN_FUNCTION for names that are not registered here.
    void unregisterFunction(const String & function_name);

private:
    std::unordered_map<String, ASTPtr> function_name_to_create_query;
};
|
||||
|
||||
}
|
99
src/Interpreters/UserDefinedFunctionsVisitor.cpp
Normal file
99
src/Interpreters/UserDefinedFunctionsVisitor.cpp
Normal file
@ -0,0 +1,99 @@
|
||||
#include "UserDefinedFunctionsVisitor.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <stack>
|
||||
|
||||
#include <Parsers/ASTFunction.h>
|
||||
#include <Parsers/ASTCreateFunctionQuery.h>
|
||||
#include <Parsers/ASTExpressionList.h>
|
||||
#include <Parsers/ASTIdentifier.h>
|
||||
#include <Interpreters/UserDefinedFunctionFactory.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int UNSUPPORTED_METHOD;
|
||||
}
|
||||
|
||||
/// If `ast` is a function node that tryToReplaceFunction() can expand, replaces it in place
/// with the expansion; any other node is left untouched. The Data argument is unused.
void UserDefinedFunctionsMatcher::visit(ASTPtr & ast, Data &)
{
    if (auto * function_node = ast->as<ASTFunction>())
    {
        auto replacement = tryToReplaceFunction(*function_node);
        if (replacement)
            ast = replacement;
    }
}
|
||||
|
||||
/// The matcher never prunes the traversal: every child of every node is reported as needing a visit.
bool UserDefinedFunctionsMatcher::needChildVisit(const ASTPtr & /*node*/, const ASTPtr & /*child*/)
{
    constexpr bool visit_every_child = true;
    return visit_every_child;
}
|
||||
|
||||
/// Expands a call of a user defined function into a copy of that function's body in which every
/// parameter identifier is replaced by a clone of the corresponding call argument.
///
/// Returns nullptr when `function.name` is not registered in UserDefinedFunctionFactory.
/// Throws UNSUPPORTED_METHOD when the number of call arguments differs from the number of declared parameters.
ASTPtr UserDefinedFunctionsMatcher::tryToReplaceFunction(const ASTFunction & function)
{
    auto user_defined_function = UserDefinedFunctionFactory::instance().tryGet(function.name);
    if (!user_defined_function)
        return nullptr;

    /// Arguments of the call being expanded.
    const auto * function_arguments_list = function.children.at(0)->as<ASTExpressionList>();
    const auto & function_arguments = function_arguments_list->children;

    /// The stored lambda: its first child holds the parameter tuple, its second child the body.
    const auto * create_function_query = user_defined_function->as<ASTCreateFunctionQuery>();
    const auto & function_core_expression = create_function_query->function_core->children.at(0);

    const auto * identifiers_expression_list = function_core_expression->children.at(0)->children.at(0)->as<ASTExpressionList>();
    const auto & identifiers_raw = identifiers_expression_list->children;

    if (function_arguments.size() != identifiers_raw.size())
        throw Exception(ErrorCodes::UNSUPPORTED_METHOD,
            "Function {} expects {} arguments actual arguments {}",
            create_function_query->function_name,
            identifiers_raw.size(),
            function_arguments.size());

    /// Parameter name -> argument expression supplied at the call site.
    std::unordered_map<std::string, ASTPtr> identifier_name_to_function_argument;

    for (size_t parameter_index = 0; parameter_index < identifiers_raw.size(); ++parameter_index)
    {
        const auto * identifier = identifiers_raw[parameter_index]->as<ASTIdentifier>();
        const auto & function_argument = function_arguments[parameter_index];

        identifier_name_to_function_argument.emplace(identifier->name(), function_argument);
    }

    /// Work on a copy so the stored function definition is never modified.
    auto function_body_to_update = function_core_expression->children.at(1)->clone();

    /// The traversal below only rewrites children of a node, so a body that is itself
    /// a single identifier (for example `a -> a`) has to be substituted here.
    if (auto root_identifier_name = tryGetIdentifierName(function_body_to_update))
    {
        auto root_argument_it = identifier_name_to_function_argument.find(*root_identifier_name);
        if (root_argument_it != identifier_name_to_function_argument.end())
            return root_argument_it->second->clone();

        return function_body_to_update;
    }

    std::stack<ASTPtr> ast_nodes_to_update;
    ast_nodes_to_update.push(function_body_to_update);

    while (!ast_nodes_to_update.empty())
    {
        auto ast_node_to_update = ast_nodes_to_update.top();
        ast_nodes_to_update.pop();

        for (auto & child : ast_node_to_update->children)
        {
            auto identifier_name_opt = tryGetIdentifierName(child);

            if (identifier_name_opt)
            {
                auto function_argument_it = identifier_name_to_function_argument.find(*identifier_name_opt);

                /// An identifier that is not a parameter is left as is instead of dereferencing end()
                /// (the previous assert compiled out in release builds).
                if (function_argument_it == identifier_name_to_function_argument.end())
                    continue;

                /// Each occurrence gets its own clone so later rewrites cannot alias each other.
                child = function_argument_it->second->clone();
                continue;
            }

            ast_nodes_to_update.push(child);
        }
    }

    return function_body_to_update;
}
|
||||
|
||||
}
|
44
src/Interpreters/UserDefinedFunctionsVisitor.h
Normal file
44
src/Interpreters/UserDefinedFunctionsVisitor.h
Normal file
@ -0,0 +1,44 @@
|
||||
#pragma once
|
||||
|
||||
#include <Interpreters/Aliases.h>
|
||||
#include <Interpreters/InDepthNodeVisitor.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class ASTFunction;
|
||||
|
||||
/** Visits ASTFunction nodes and, if a node is a call of a user defined function, replaces it with the function body.
|
||||
* Example:
|
||||
*
|
||||
* CREATE FUNCTION test_function AS a -> a + 1;
|
||||
*
|
||||
* Before applying visitor:
|
||||
* SELECT test_function(number) FROM system.numbers LIMIT 10;
|
||||
*
|
||||
* After applying visitor:
|
||||
* SELECT number + 1 FROM system.numbers LIMIT 10;
|
||||
*/
|
||||
class UserDefinedFunctionsMatcher
|
||||
{
|
||||
public:
|
||||
using Visitor = InDepthNodeVisitor<UserDefinedFunctionsMatcher, true>;
|
||||
|
||||
struct Data
|
||||
{
|
||||
};
|
||||
|
||||
static void visit(ASTPtr & ast, Data & data);
|
||||
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child);
|
||||
|
||||
private:
|
||||
static void visit(ASTFunction & func, const Data & data);
|
||||
|
||||
static ASTPtr tryToReplaceFunction(const ASTFunction & function);
|
||||
|
||||
};
|
||||
|
||||
/// Visits AST nodes and replaces calls of user defined functions with the bodies of those functions.
|
||||
using UserDefinedFunctionsVisitor = UserDefinedFunctionsMatcher::Visitor;
|
||||
|
||||
}
|
@ -47,9 +47,10 @@ void UserDefinedObjectsLoader::loadUserDefinedObject(ContextPtr context, UserDef
|
||||
auto name_ref = StringRef(name.data(), name.size());
|
||||
LOG_DEBUG(log, "Loading user defined object {} from file {}", backQuote(name_ref), path);
|
||||
|
||||
String object_create_query;
|
||||
/// There is .sql file with user defined object creation statement.
|
||||
ReadBufferFromFile in(path, 1024);
|
||||
ReadBufferFromFile in(path);
|
||||
|
||||
String object_create_query;
|
||||
readStringUntilEOF(object_create_query, in);
|
||||
|
||||
try
|
||||
@ -119,7 +120,6 @@ void UserDefinedObjectsLoader::storeObject(ContextPtr context, UserDefinedObject
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "UserDefinedObjectsLoader::storeObject " << file_path << std::endl;
|
||||
if (std::filesystem::exists(file_path))
|
||||
throw Exception(ErrorCodes::OBJECT_ALREADY_STORED_ON_DISK, "User defined object {} already stored on disk", backQuote(file_path));
|
||||
|
||||
@ -145,17 +145,16 @@ void UserDefinedObjectsLoader::removeObject(ContextPtr context, UserDefinedObjec
|
||||
String dir_path = context->getPath() + "user_defined/";
|
||||
LOG_DEBUG(log, "Removing file for user defined object {} from {}", backQuote(object_name), dir_path);
|
||||
|
||||
String file_path_name;
|
||||
std::filesystem::path file_path;
|
||||
|
||||
switch (object_type)
|
||||
{
|
||||
case UserDefinedObjectType::Function:
|
||||
{
|
||||
file_path_name = dir_path + "function_" + escapeForFileName(object_name) + ".sql";
|
||||
file_path = dir_path + "function_" + escapeForFileName(object_name) + ".sql";
|
||||
}
|
||||
}
|
||||
|
||||
std::filesystem::path file_path(file_path_name);
|
||||
if (!std::filesystem::exists(file_path))
|
||||
throw Exception(ErrorCodes::OBJECT_WAS_NOT_STORED_ON_DISK, "User defined object {} was not stored on disk", backQuote(file_path.string()));
|
||||
|
||||
|
@ -1,26 +1,38 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Parsers/queryToString.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/UserDefinedFunctionFactory.h>
|
||||
#include <Storages/System/StorageSystemFunctions.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace
|
||||
{
|
||||
template <typename Factory>
|
||||
void fillRow(MutableColumns & res_columns, const String & name, UInt64 is_aggregate, const Factory & f)
|
||||
void fillRow(MutableColumns & res_columns, const String & name, UInt64 is_aggregate, const String & create_query, const Factory & f)
|
||||
{
|
||||
res_columns[0]->insert(name);
|
||||
res_columns[1]->insert(is_aggregate);
|
||||
res_columns[2]->insert(f.isCaseInsensitive(name));
|
||||
if (f.isAlias(name))
|
||||
res_columns[3]->insert(f.aliasTo(name));
|
||||
else
|
||||
|
||||
if constexpr (std::is_same_v<Factory, UserDefinedFunctionFactory>)
|
||||
{
|
||||
res_columns[2]->insert(false);
|
||||
res_columns[3]->insertDefault();
|
||||
}
|
||||
else
|
||||
{
|
||||
res_columns[2]->insert(f.isCaseInsensitive(name));
|
||||
if (f.isAlias(name))
|
||||
res_columns[3]->insert(f.aliasTo(name));
|
||||
else
|
||||
res_columns[3]->insertDefault();
|
||||
}
|
||||
|
||||
res_columns[4]->insert(create_query);
|
||||
}
|
||||
}
|
||||
|
||||
@ -31,6 +43,7 @@ NamesAndTypesList StorageSystemFunctions::getNamesAndTypes()
|
||||
{"is_aggregate", std::make_shared<DataTypeUInt8>()},
|
||||
{"case_insensitive", std::make_shared<DataTypeUInt8>()},
|
||||
{"alias_to", std::make_shared<DataTypeString>()},
|
||||
{"create_query", std::make_shared<DataTypeString>()}
|
||||
};
|
||||
}
|
||||
|
||||
@ -40,14 +53,22 @@ void StorageSystemFunctions::fillData(MutableColumns & res_columns, ContextPtr,
|
||||
const auto & function_names = functions_factory.getAllRegisteredNames();
|
||||
for (const auto & function_name : function_names)
|
||||
{
|
||||
fillRow(res_columns, function_name, UInt64(0), functions_factory);
|
||||
fillRow(res_columns, function_name, UInt64(0), "", functions_factory);
|
||||
}
|
||||
|
||||
const auto & aggregate_functions_factory = AggregateFunctionFactory::instance();
|
||||
const auto & aggregate_function_names = aggregate_functions_factory.getAllRegisteredNames();
|
||||
for (const auto & function_name : aggregate_function_names)
|
||||
{
|
||||
fillRow(res_columns, function_name, UInt64(1), aggregate_functions_factory);
|
||||
fillRow(res_columns, function_name, UInt64(1), "", aggregate_functions_factory);
|
||||
}
|
||||
|
||||
const auto & user_defined_functions_factory = UserDefinedFunctionFactory::instance();
|
||||
const auto & user_defined_functions_names = user_defined_functions_factory.getAllRegisteredNames();
|
||||
for (const auto & function_name : user_defined_functions_names)
|
||||
{
|
||||
auto create_query = queryToString(user_defined_functions_factory.get(function_name));
|
||||
fillRow(res_columns, function_name, UInt64(0), create_query, user_defined_functions_factory);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,13 +1,17 @@
|
||||
create function MyFunc2 as (a, b) -> a || b || c; -- { serverError 47 }
|
||||
-- CREATE FUNCTION MyFunc2 AS (a, b) -> a || b || c; --{serverError 47}
|
||||
|
||||
create function MyFunc2 as (a, b) -> MyFunc2(a, b) + MyFunc2(a, b); -- { serverError 593 } recursive function
|
||||
-- CREATE FUNCTION MyFunc2 AS (a, b) -> MyFunc2(a, b) + MyFunc2(a, b); --{serverError 600}
|
||||
|
||||
create function cast as a -> a + 1; -- { serverError 591 } function already exist
|
||||
-- CREATE FUNCTION cast AS a -> a + 1; --{serverError 598}
|
||||
|
||||
create function sum as (a, b) -> a + b; -- { serverError 591 } aggregate function already exist
|
||||
-- CREATE FUNCTION sum AS (a, b) -> a + b; --{serverError 598}
|
||||
|
||||
create function MyFunc3 as (a, b) -> a + b;
|
||||
-- CREATE FUNCTION MyFunc3 AS (a, b) -> a + b;
|
||||
|
||||
create function MyFunc3 as (a) -> a || '!!!'; -- { serverError 591 } function already exist
|
||||
-- CREATE FUNCTION MyFunc3 AS (a) -> a || '!!!'; --{serverError 598}
|
||||
|
||||
drop function MyFunc3;
|
||||
-- DROP FUNCTION MyFunc3;
|
||||
|
||||
-- DROP FUNCTION unknownFunc; -- {serverError 46}
|
||||
|
||||
DROP FUNCTION CAST; -- {serverError 599}
|
||||
|
@ -1,3 +0,0 @@
|
||||
0
|
||||
1
|
||||
0
|
@ -1,5 +0,0 @@
|
||||
create function MyFunc as (a, b, c) -> a + b > c AND c < 10;
|
||||
select MyFunc(1, 2, 3);
|
||||
select MyFunc(2, 2, 3);
|
||||
select MyFunc(20, 20, 11);
|
||||
drop function MyFunc;
|
@ -1,2 +0,0 @@
|
||||
drop function unknownFunc; -- { serverError 46 }
|
||||
drop function CAST; -- { serverError 592 }
|
Loading…
Reference in New Issue
Block a user