Add feature: create user defined function as lambda

This commit is contained in:
ANDREI STAROVEROV 2021-05-09 12:47:29 +03:00
parent 6f08f945e8
commit 13a7169669
15 changed files with 308 additions and 1 deletions

View File

@ -1,4 +1,5 @@
#include <Functions/FunctionFactory.h>
#include <Functions/UserDefinedFunction.h>
#include <Interpreters/Context.h>
@ -133,4 +134,19 @@ FunctionFactory & FunctionFactory::instance()
return ret;
}
void FunctionFactory::registerUserDefinedFunction(
const ASTCreateFunctionQuery & create_function_query,
CaseSensitiveness case_sensitiveness)
{
registerFunction(create_function_query.function_name, [create_function_query](ContextPtr context)
{
auto function = UserDefinedFunction::create(context);
function->setName(create_function_query.function_name);
function->setFunctionCore(create_function_query.function_core);
FunctionOverloadResolverImplPtr res = std::make_unique<DefaultOverloadResolver>(function);
return res;
}, case_sensitiveness);
}
}

View File

@ -1,8 +1,10 @@
#pragma once
#include <Functions/IFunctionAdaptors.h>
#include <Functions/UserDefinedFunction.h>
#include <Interpreters/Context_fwd.h>
#include <Common/IFactoryWithAliases.h>
#include <Parsers/ASTCreateFunctionQuery.h>
#include <functional>
#include <memory>
@ -13,6 +15,8 @@
namespace DB
{
class UserDefinedFunction;
/** Creates function by name.
* Function could use for initialization (take ownership of shared_ptr, for example)
* some dictionaries from Context.
@ -38,6 +42,10 @@ public:
registerFunction(name, &Function::create, case_sensitiveness);
}
void registerUserDefinedFunction(
const ASTCreateFunctionQuery & create_function_query,
CaseSensitiveness case_sensitiveness = CaseSensitive);
/// This function is used by YQL - internal Yandex product that depends on ClickHouse by source code.
std::vector<std::string> getAllNames() const;

View File

@ -0,0 +1,90 @@
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/UserDefinedFunction.h>
#include <Interpreters/TreeRewriter.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Parsers/ASTIdentifier.h>
namespace DB
{
namespace ErrorCodes
{
extern const int TYPE_MISMATCH;
}
UserDefinedFunction::UserDefinedFunction(ContextPtr context_)
: function_core(nullptr)
, context(context_)
{}
UserDefinedFunctionPtr UserDefinedFunction::create(ContextPtr context)
{
return std::make_shared<UserDefinedFunction>(context);
}
String UserDefinedFunction::getName() const
{
return name;
}
ColumnPtr UserDefinedFunction::executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t) const
{
Block block = executeCore(arguments);
String result_name = function_core->as<ASTFunction>()->arguments->children.at(1)->getColumnName();
// result of function executing was inserted in the end
return block.getColumns().back();
}
size_t UserDefinedFunction::getNumberOfArguments() const
{
return function_core->as<ASTFunction>()->arguments->children[0]->size() - 2;
}
void UserDefinedFunction::setName(const String & name_)
{
name = name_;
}
void UserDefinedFunction::setFunctionCore(ASTPtr function_core_)
{
function_core = function_core_;
}
DataTypePtr UserDefinedFunction::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
{
Block block = executeCore(arguments);
return block.getDataTypes().back();
}
Block UserDefinedFunction::executeCore(const ColumnsWithTypeAndName & arguments) const
{
const auto * lambda_args_tuple = function_core->as<ASTFunction>()->arguments->children.at(0)->as<ASTFunction>();
const ASTs & lambda_arg_asts = lambda_args_tuple->arguments->children;
NamesAndTypesList lambda_arguments;
Block block;
for (size_t j = 0; j < lambda_arg_asts.size(); ++j)
{
auto opt_arg_name = tryGetIdentifierName(lambda_arg_asts[j]);
if (!opt_arg_name)
throw Exception("lambda argument declarations must be identifiers", ErrorCodes::TYPE_MISMATCH);
lambda_arguments.emplace_back(*opt_arg_name, arguments[j].type);
block.insert({arguments[j].column, arguments[j].type, *opt_arg_name});
}
ASTPtr lambda_body = function_core->as<ASTFunction>()->children.at(0)->children.at(1);
auto syntax_result = TreeRewriter(context).analyze(lambda_body, lambda_arguments);
ExpressionAnalyzer analyzer(lambda_body, syntax_result, context);
ExpressionActionsPtr actions = analyzer.getActions(false);
actions->execute(block);
return block;
}
}

View File

@ -0,0 +1,37 @@
#pragma once
#include <Functions/FunctionFactory.h>
#include <Functions/IFunctionImpl.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTCreateFunctionQuery.h>
namespace DB
{
class UserDefinedFunction;
using UserDefinedFunctionPtr = std::shared_ptr<UserDefinedFunction>;
class UserDefinedFunction : public IFunction
{
public:
explicit UserDefinedFunction(ContextPtr context_);
static UserDefinedFunctionPtr create(ContextPtr context);
String getName() const override;
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override;
size_t getNumberOfArguments() const override;
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
void setName(const String & name_);
void setFunctionCore(ASTPtr function_core_);
private:
Block executeCore(const ColumnsWithTypeAndName & arguments) const;
private:
String name;
ASTPtr function_core;
ContextPtr context;
};
}

View File

@ -104,6 +104,7 @@ SRCS(
URL/registerFunctionsURL.cpp
URL/tldLookup.generated.cpp
URL/topLevelDomain.cpp
UserDefinedFunction.cpp
abs.cpp
acos.cpp
acosh.cpp

View File

@ -0,0 +1,18 @@
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/InterpreterCreateFunctionQuery.h>
#include <Interpreters/FunctionNameNormalizer.h>
#include <Functions/FunctionFactory.h>
namespace DB
{
BlockIO InterpreterCreateFunctionQuery::execute()
{
FunctionNameNormalizer().visit(query_ptr.get());
auto & create_function_query = query_ptr->as<ASTCreateFunctionQuery &>();
FunctionFactory::instance().registerUserDefinedFunction(create_function_query);
return {};
}
}

View File

@ -0,0 +1,22 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/ASTCreateFunctionQuery.h>
namespace DB
{
class ASTCreateFunctionQuery;
class InterpreterCreateFunctionQuery : public IInterpreter, WithContext
{
public:
InterpreterCreateFunctionQuery(const ASTPtr & query_ptr_, ContextPtr context_) : WithContext(context_), query_ptr(query_ptr_) {}
BlockIO execute() override;
private:
ASTPtr query_ptr;
};
}

View File

@ -33,6 +33,7 @@
#include <Interpreters/Context.h>
#include <Interpreters/InterpreterAlterQuery.h>
#include <Interpreters/InterpreterCheckQuery.h>
#include <Interpreters/InterpreterCreateFunctionQuery.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <Interpreters/InterpreterCreateQuotaQuery.h>
#include <Interpreters/InterpreterCreateRoleQuery.h>
@ -264,6 +265,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, ContextPtr
{
return std::make_unique<InterpreterExternalDDLQuery>(query, context);
}
else if (query->as<ASTCreateFunctionQuery>())
{
return std::make_unique<InterpreterCreateFunctionQuery>(query, context);
}
else
{
throw Exception("Unknown type of query: " + query->getID(), ErrorCodes::UNKNOWN_TYPE_OF_QUERY);

View File

@ -54,7 +54,7 @@ SRCS(
ExpressionAnalyzer.cpp
ExternalDictionariesLoader.cpp
ExternalLoader.cpp
ExternalLoaderDictionaryStorageConfigRepository.cpp
ExternalLoaderDatabaseConfigRepository.cpp
ExternalLoaderTempConfigRepository.cpp
ExternalLoaderXMLConfigRepository.cpp
ExternalModelsLoader.cpp
@ -70,6 +70,7 @@ SRCS(
InternalTextLogsQueue.cpp
InterpreterAlterQuery.cpp
InterpreterCheckQuery.cpp
InterpreterCreateFunctionQuery.cpp
InterpreterCreateQuery.cpp
InterpreterCreateQuotaQuery.cpp
InterpreterCreateRoleQuery.cpp

View File

@ -0,0 +1,21 @@
#include <Common/quoteString.h>
#include <IO/Operators.h>
#include <Parsers/ASTCreateFunctionQuery.h>
namespace DB
{
ASTPtr ASTCreateFunctionQuery::clone() const
{
return std::make_shared<ASTCreateFunctionQuery>(*this);
}
void ASTCreateFunctionQuery::formatImpl(const IAST::FormatSettings & settings, IAST::FormatState & state, IAST::FormatStateStacked frame) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << "CREATE FUNCTION " << (settings.hilite ? hilite_none : "");
settings.ostr << (settings.hilite ? hilite_identifier : "") << backQuoteIfNeed(function_name) << (settings.hilite ? hilite_none : "");
settings.ostr << (settings.hilite ? hilite_keyword : "") << " AS " << (settings.hilite ? hilite_none : "");
function_core->formatImpl(settings, state, frame);
}
}

View File

@ -0,0 +1,22 @@
#pragma once
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTFunction.h>
namespace DB
{
class ASTCreateFunctionQuery : public IAST
{
public:
String function_name;
ASTPtr function_core;
String getID(char) const override { return "CreateFunctionQuery"; }
ASTPtr clone() const override;
void formatImpl(const FormatSettings & s, FormatState & state, FormatStateStacked frame) const override;
};
}

View File

@ -0,0 +1,46 @@
#include <Parsers/ASTCreateFunctionQuery.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/ExpressionElementParsers.h>
#include <Parsers/ExpressionListParsers.h>
#include <Parsers/ParserCreateFunctionQuery.h>
namespace DB
{
bool ParserCreateFunctionQuery::parseImpl(IParser::Pos & pos, ASTPtr & node, Expected & expected)
{
ParserKeyword s_create("CREATE");
ParserKeyword s_function("FUNCTION");
ParserIdentifier function_name_p;
ParserKeyword s_as("AS");
ParserLambdaExpression lambda_p;
ASTPtr function_name;
ASTPtr function_core;
if (!s_create.ignore(pos, expected))
return false;
if (!s_function.ignore(pos, expected))
return false;
if (!function_name_p.parse(pos, function_name, expected))
return false;
if (!s_as.ignore(pos, expected))
return false;
if (!lambda_p.parse(pos, function_core, expected))
return false;
auto create_function_query = std::make_shared<ASTCreateFunctionQuery>();
node = create_function_query;
create_function_query->function_name = function_name->as<ASTIdentifier &>().name();
create_function_query->function_core = function_core;
return true;
}
}

View File

@ -0,0 +1,15 @@
#pragma once
#include "IParserBase.h"
namespace DB
{
/// CREATE FUNCTION test AS x -> x || '1'
class ParserCreateFunctionQuery : public IParserBase
{
protected:
const char * getName() const override { return "CREATE FUNCTION query"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
}

View File

@ -1,4 +1,5 @@
#include <Parsers/ParserAlterQuery.h>
#include <Parsers/ParserCreateFunctionQuery.h>
#include <Parsers/ParserCreateQuery.h>
#include <Parsers/ParserCreateQuotaQuery.h>
#include <Parsers/ParserCreateRoleQuery.h>
@ -36,6 +37,7 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
ParserCreateQuotaQuery create_quota_p;
ParserCreateRowPolicyQuery create_row_policy_p;
ParserCreateSettingsProfileQuery create_settings_profile_p;
ParserCreateFunctionQuery create_function_p;
ParserDropAccessEntityQuery drop_access_entity_p;
ParserGrantQuery grant_p;
ParserSetRoleQuery set_role_p;
@ -52,6 +54,7 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
|| create_quota_p.parse(pos, node, expected)
|| create_row_policy_p.parse(pos, node, expected)
|| create_settings_profile_p.parse(pos, node, expected)
|| create_function_p.parse(pos, node, expected)
|| drop_access_entity_p.parse(pos, node, expected)
|| grant_p.parse(pos, node, expected)
|| external_ddl_p.parse(pos, node, expected);

View File

@ -15,6 +15,7 @@ SRCS(
ASTColumnsMatcher.cpp
ASTColumnsTransformers.cpp
ASTConstraintDeclaration.cpp
ASTCreateFunctionQuery.cpp
ASTCreateQuery.cpp
ASTCreateQuotaQuery.cpp
ASTCreateRoleQuery.cpp
@ -86,6 +87,7 @@ SRCS(
ParserAlterQuery.cpp
ParserCase.cpp
ParserCheckQuery.cpp
ParserCreateFunctionQuery.cpp
ParserCreateQuery.cpp
ParserCreateQuotaQuery.cpp
ParserCreateRoleQuery.cpp