From 13a71696692b19dd6210b9d0d1ae58dbc92f7762 Mon Sep 17 00:00:00 2001 From: ANDREI STAROVEROV Date: Sun, 9 May 2021 12:47:29 +0300 Subject: [PATCH] Add feature: create user defined function as lambda --- src/Functions/FunctionFactory.cpp | 16 ++++ src/Functions/FunctionFactory.h | 8 ++ src/Functions/UserDefinedFunction.cpp | 90 +++++++++++++++++++ src/Functions/UserDefinedFunction.h | 37 ++++++++ src/Functions/ya.make | 1 + .../InterpreterCreateFunctionQuery.cpp | 18 ++++ .../InterpreterCreateFunctionQuery.h | 22 +++++ src/Interpreters/InterpreterFactory.cpp | 5 ++ src/Interpreters/ya.make | 3 +- src/Parsers/ASTCreateFunctionQuery.cpp | 21 +++++ src/Parsers/ASTCreateFunctionQuery.h | 22 +++++ src/Parsers/ParserCreateFunctionQuery.cpp | 46 ++++++++++ src/Parsers/ParserCreateFunctionQuery.h | 15 ++++ src/Parsers/ParserQuery.cpp | 3 + src/Parsers/ya.make | 2 + 15 files changed, 308 insertions(+), 1 deletion(-) create mode 100644 src/Functions/UserDefinedFunction.cpp create mode 100644 src/Functions/UserDefinedFunction.h create mode 100644 src/Interpreters/InterpreterCreateFunctionQuery.cpp create mode 100644 src/Interpreters/InterpreterCreateFunctionQuery.h create mode 100644 src/Parsers/ASTCreateFunctionQuery.cpp create mode 100644 src/Parsers/ASTCreateFunctionQuery.h create mode 100644 src/Parsers/ParserCreateFunctionQuery.cpp create mode 100644 src/Parsers/ParserCreateFunctionQuery.h diff --git a/src/Functions/FunctionFactory.cpp b/src/Functions/FunctionFactory.cpp index 35ac9ab647b..7f330d45c37 100644 --- a/src/Functions/FunctionFactory.cpp +++ b/src/Functions/FunctionFactory.cpp @@ -1,4 +1,5 @@ #include +#include #include @@ -133,4 +134,19 @@ FunctionFactory & FunctionFactory::instance() return ret; } +void FunctionFactory::registerUserDefinedFunction( + const ASTCreateFunctionQuery & create_function_query, + CaseSensitiveness case_sensitiveness) +{ + registerFunction(create_function_query.function_name, [create_function_query](ContextPtr context) + { + auto function = UserDefinedFunction::create(context); + function->setName(create_function_query.function_name); + function->setFunctionCore(create_function_query.function_core); + + FunctionOverloadResolverImplPtr res = std::make_unique(function); + return res; + }, case_sensitiveness); +} + } diff --git a/src/Functions/FunctionFactory.h b/src/Functions/FunctionFactory.h index 96238a88420..176178f7593 100644 --- a/src/Functions/FunctionFactory.h +++ b/src/Functions/FunctionFactory.h @@ -1,8 +1,10 @@ #pragma once #include +#include #include #include +#include #include #include @@ -13,6 +15,8 @@ namespace DB { +class UserDefinedFunction; + /** Creates function by name. * Function could use for initialization (take ownership of shared_ptr, for example) * some dictionaries from Context. @@ -38,6 +42,10 @@ public: registerFunction(name, &Function::create, case_sensitiveness); } + void registerUserDefinedFunction( + const ASTCreateFunctionQuery & create_function_query, + CaseSensitiveness case_sensitiveness = CaseSensitive); + /// This function is used by YQL - internal Yandex product that depends on ClickHouse by source code. std::vector getAllNames() const; diff --git a/src/Functions/UserDefinedFunction.cpp b/src/Functions/UserDefinedFunction.cpp new file mode 100644 index 00000000000..b7b4ff8de3e --- /dev/null +++ b/src/Functions/UserDefinedFunction.cpp @@ -0,0 +1,90 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int TYPE_MISMATCH; +} + +UserDefinedFunction::UserDefinedFunction(ContextPtr context_) + : function_core(nullptr) + , context(context_) +{} + +UserDefinedFunctionPtr UserDefinedFunction::create(ContextPtr context) +{ + return std::make_shared(context); +} + +String UserDefinedFunction::getName() const +{ + return name; +} + +ColumnPtr UserDefinedFunction::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const +{ + Block block = executeCore(arguments); + + String result_name = function_core->as()->arguments->children.at(1)->getColumnName(); + + // result of function executing was inserted in the end + return block.getColumns().back(); +} + +size_t UserDefinedFunction::getNumberOfArguments() const +{ + return function_core->as()->arguments->children[0]->size() - 2; +} + +void UserDefinedFunction::setName(const String & name_) +{ + name = name_; +} + +void UserDefinedFunction::setFunctionCore(ASTPtr function_core_) +{ + function_core = function_core_; +} + +DataTypePtr UserDefinedFunction::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const +{ + Block block = executeCore(arguments); + return block.getDataTypes().back(); +} + +Block UserDefinedFunction::executeCore(const ColumnsWithTypeAndName & arguments) const +{ + const auto * lambda_args_tuple = function_core->as()->arguments->children.at(0)->as(); + const ASTs & lambda_arg_asts = lambda_args_tuple->arguments->children; + + NamesAndTypesList lambda_arguments; + Block block; + + for (size_t j = 0; j < lambda_arg_asts.size(); ++j) + { + auto opt_arg_name = tryGetIdentifierName(lambda_arg_asts[j]); + if (!opt_arg_name) + throw Exception("lambda argument declarations must be identifiers", ErrorCodes::TYPE_MISMATCH); + + lambda_arguments.emplace_back(*opt_arg_name, arguments[j].type); + block.insert({arguments[j].column, arguments[j].type, *opt_arg_name}); + } + + ASTPtr lambda_body = function_core->as()->children.at(0)->children.at(1); + auto syntax_result = TreeRewriter(context).analyze(lambda_body, lambda_arguments); + ExpressionAnalyzer analyzer(lambda_body, syntax_result, context); + ExpressionActionsPtr actions = analyzer.getActions(false); + + actions->execute(block); + return block; +} + +} diff --git a/src/Functions/UserDefinedFunction.h b/src/Functions/UserDefinedFunction.h new file mode 100644 index 00000000000..2b519740204 --- /dev/null +++ b/src/Functions/UserDefinedFunction.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include +#include + +namespace DB +{ + +class UserDefinedFunction; +using UserDefinedFunctionPtr = std::shared_ptr; + +class UserDefinedFunction : public IFunction +{ +public: + explicit UserDefinedFunction(ContextPtr context_); + static UserDefinedFunctionPtr create(ContextPtr context); + + String getName() const override; + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override; + size_t getNumberOfArguments() const override; + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override; + + void setName(const String & name_); + void setFunctionCore(ASTPtr function_core_); + +private: + Block executeCore(const ColumnsWithTypeAndName & arguments) const; + +private: + String name; + ASTPtr function_core; + ContextPtr context; +}; + +} diff --git a/src/Functions/ya.make b/src/Functions/ya.make index 2a541369ff4..fef9bae3685 100644 --- a/src/Functions/ya.make +++ b/src/Functions/ya.make @@ -104,6 +104,7 @@ SRCS( URL/registerFunctionsURL.cpp URL/tldLookup.generated.cpp URL/topLevelDomain.cpp + UserDefinedFunction.cpp abs.cpp acos.cpp acosh.cpp diff --git a/src/Interpreters/InterpreterCreateFunctionQuery.cpp b/src/Interpreters/InterpreterCreateFunctionQuery.cpp new file mode 100644 index 00000000000..4fa524534f3 --- /dev/null +++ b/src/Interpreters/InterpreterCreateFunctionQuery.cpp @@ -0,0 +1,18 @@ +#include +#include +#include +#include +#include + +namespace DB +{ + +BlockIO InterpreterCreateFunctionQuery::execute() +{ + FunctionNameNormalizer().visit(query_ptr.get()); + auto & create_function_query = query_ptr->as(); + FunctionFactory::instance().registerUserDefinedFunction(create_function_query); + return {}; +} + +} diff --git a/src/Interpreters/InterpreterCreateFunctionQuery.h b/src/Interpreters/InterpreterCreateFunctionQuery.h new file mode 100644 index 00000000000..81347bcc711 --- /dev/null +++ b/src/Interpreters/InterpreterCreateFunctionQuery.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +namespace DB +{ + +class ASTCreateFunctionQuery; + +class InterpreterCreateFunctionQuery : public IInterpreter, WithContext +{ +public: + InterpreterCreateFunctionQuery(const ASTPtr & query_ptr_, ContextPtr context_) : WithContext(context_), query_ptr(query_ptr_) {} + + BlockIO execute() override; + +private: + ASTPtr query_ptr; +}; + +} diff --git a/src/Interpreters/InterpreterFactory.cpp b/src/Interpreters/InterpreterFactory.cpp index 4af8b6ffa7d..54122292589 100644 --- a/src/Interpreters/InterpreterFactory.cpp +++ b/src/Interpreters/InterpreterFactory.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -264,6 +265,10 @@ std::unique_ptr InterpreterFactory::get(ASTPtr & query, ContextPtr { return std::make_unique(query, context); } + else if (query->as()) + { + return std::make_unique(query, context); + } else { throw Exception("Unknown type of query: " + query->getID(), ErrorCodes::UNKNOWN_TYPE_OF_QUERY); diff --git a/src/Interpreters/ya.make b/src/Interpreters/ya.make index 105e1e11365..5c49c7e8946 100644 --- a/src/Interpreters/ya.make +++ b/src/Interpreters/ya.make @@ -54,7 +54,7 @@ SRCS( ExpressionAnalyzer.cpp ExternalDictionariesLoader.cpp ExternalLoader.cpp - ExternalLoaderDictionaryStorageConfigRepository.cpp + ExternalLoaderDatabaseConfigRepository.cpp ExternalLoaderTempConfigRepository.cpp ExternalLoaderXMLConfigRepository.cpp ExternalModelsLoader.cpp @@ -70,6 +70,7 @@ SRCS( InternalTextLogsQueue.cpp InterpreterAlterQuery.cpp InterpreterCheckQuery.cpp + InterpreterCreateFunctionQuery.cpp InterpreterCreateQuery.cpp InterpreterCreateQuotaQuery.cpp InterpreterCreateRoleQuery.cpp diff --git a/src/Parsers/ASTCreateFunctionQuery.cpp b/src/Parsers/ASTCreateFunctionQuery.cpp new file mode 100644 index 00000000000..0b3991ddc44 --- /dev/null +++ b/src/Parsers/ASTCreateFunctionQuery.cpp @@ -0,0 +1,21 @@ +#include +#include +#include + +namespace DB +{ + +ASTPtr ASTCreateFunctionQuery::clone() const +{ + return std::make_shared(*this); +} + +void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, IAST::FormatState & state, IAST::FormatStateStacked frame) const +{ + settings.ostr << (settings.hilite ? hilite_keyword : "") << "CREATE FUNCTION " << (settings.hilite ? hilite_none : ""); + settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(function_name) << (settings.hilite ? hilite_none : ""); + settings.ostr << (settings.hilite ? hilite_keyword : "") << " AS " << (settings.hilite ? hilite_none : ""); + function_core->formatImpl(settings, state, frame); +} + +} diff --git a/src/Parsers/ASTCreateFunctionQuery.h b/src/Parsers/ASTCreateFunctionQuery.h new file mode 100644 index 00000000000..3adddad8fbd --- /dev/null +++ b/src/Parsers/ASTCreateFunctionQuery.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +namespace DB +{ + +class ASTCreateFunctionQuery : public IAST +{ +public: + String function_name; + ASTPtr function_core; + + String getID(char) const override { return "CreateFunctionQuery"; } + + ASTPtr clone() const override; + + void formatImpl(const FormatSettings & s, FormatState & state, FormatStateStacked frame) const override; +}; + +} diff --git a/src/Parsers/ParserCreateFunctionQuery.cpp b/src/Parsers/ParserCreateFunctionQuery.cpp new file mode 100644 index 00000000000..1fcce6cbf45 --- /dev/null +++ b/src/Parsers/ParserCreateFunctionQuery.cpp @@ -0,0 +1,46 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & expected) +{ + ParserKeyword s_create("CREATE"); + ParserKeyword s_function("FUNCTION"); + ParserIdentifier function_name_p; + ParserKeyword s_as("AS"); + ParserLambdaExpression lambda_p; + + ASTPtr function_name; + ASTPtr function_core; + + if (!s_create.ignore(pos, expected)) + return false; + + if (!s_function.ignore(pos, expected)) + return false; + + if (!function_name_p.parse(pos, function_name, expected)) + return false; + + if (!s_as.ignore(pos, expected)) + return false; + + if (!lambda_p.parse(pos, function_core, expected)) + return false; + + auto create_function_query = std::make_shared(); + node = create_function_query; + + create_function_query->function_name = function_name->as().name(); + create_function_query->function_core = function_core; + + return true; +} + +} diff --git a/src/Parsers/ParserCreateFunctionQuery.h b/src/Parsers/ParserCreateFunctionQuery.h new file mode 100644 index 00000000000..a48bbdeb563 --- /dev/null +++ b/src/Parsers/ParserCreateFunctionQuery.h @@ -0,0 +1,15 @@ +#pragma once + +#include "IParserBase.h" + +namespace DB +{ +/// CREATE FUNCTION test AS x -> x || '1' +class ParserCreateFunctionQuery : public IParserBase +{ +protected: + const char * getName() const override { return "CREATE FUNCTION query"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; +}; + +} diff --git a/src/Parsers/ParserQuery.cpp b/src/Parsers/ParserQuery.cpp index 4550bdc8a75..274dc0201b3 100644 --- a/src/Parsers/ParserQuery.cpp +++ b/src/Parsers/ParserQuery.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -36,6 +37,7 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) ParserCreateQuotaQuery create_quota_p; ParserCreateRowPolicyQuery create_row_policy_p; ParserCreateSettingsProfileQuery create_settings_profile_p; + ParserCreateFunctionQuery create_function_p; ParserDropAccessEntityQuery drop_access_entity_p; ParserGrantQuery grant_p; ParserSetRoleQuery set_role_p; @@ -52,6 +54,7 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) || create_quota_p.parse(pos, node, expected) || create_row_policy_p.parse(pos, node, expected) || create_settings_profile_p.parse(pos, node, expected) + || create_function_p.parse(pos, node, expected) || drop_access_entity_p.parse(pos, node, expected) || grant_p.parse(pos, node, expected) || external_ddl_p.parse(pos, node, expected); diff --git a/src/Parsers/ya.make b/src/Parsers/ya.make index 4bd31cb79de..3b9fcb33e0f 100644 --- a/src/Parsers/ya.make +++ b/src/Parsers/ya.make @@ -15,6 +15,7 @@ SRCS( ASTColumnsMatcher.cpp ASTColumnsTransformers.cpp ASTConstraintDeclaration.cpp + ASTCreateFunctionQuery.cpp ASTCreateQuery.cpp ASTCreateQuotaQuery.cpp ASTCreateRoleQuery.cpp @@ -86,6 +87,7 @@ SRCS( ParserAlterQuery.cpp ParserCase.cpp ParserCheckQuery.cpp + ParserCreateFunctionQuery.cpp ParserCreateQuery.cpp ParserCreateQuotaQuery.cpp ParserCreateRoleQuery.cpp