Added tests

This commit is contained in:
Maksim Kita 2021-09-09 16:47:48 +03:00
parent de20e04dfd
commit 93ecbf3ae4
12 changed files with 182 additions and 90 deletions

View File

@ -37,6 +37,8 @@
#include <Interpreters/interpretSubquery.h>
#include <Interpreters/DatabaseAndTableWithAlias.h>
#include <Interpreters/IdentifierSemantic.h>
#include <Interpreters/UserDefinedExecutableFunctionFactory.h>
namespace DB
{
@ -854,17 +856,21 @@ void ActionsMatcher::visit(const ASTFunction & node, const ASTPtr & ast, Data &
if (AggregateFunctionFactory::instance().isAggregateFunctionName(node.name))
return;
FunctionOverloadResolverPtr function_builder;
try
FunctionOverloadResolverPtr function_builder = UserDefinedExecutableFunctionFactory::instance().tryGet(node.name, data.getContext());
if (!function_builder)
{
function_builder = FunctionFactory::instance().get(node.name, data.getContext());
}
catch (Exception & e)
{
auto hints = AggregateFunctionFactory::instance().getHints(node.name);
if (!hints.empty())
e.addMessage("Or unknown aggregate function " + node.name + ". Maybe you meant: " + toString(hints));
throw;
try
{
function_builder = FunctionFactory::instance().get(node.name, data.getContext());
}
catch (Exception & e)
{
auto hints = AggregateFunctionFactory::instance().getHints(node.name);
if (!hints.empty())
e.addMessage("Or unknown aggregate function " + node.name + ". Maybe you meant: " + toString(hints));
throw;
}
}
Names argument_names;

View File

@ -1,16 +1,18 @@
#include "ExternalUserDefinedExecutableFunctionsLoader.h"
#include <Interpreters/UserDefinedExecutableFunction.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataStreams/ShellCommandSource.h>
#include <DataStreams/formatBlock.h>
#include <DataTypes/DataTypeFactory.h>
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Interpreters/UserDefinedExecutableFunction.h>
#include <Interpreters/UserDefinedExecutableFunctionFactory.h>
namespace DB
{
@ -30,8 +32,6 @@ public:
, config(config_)
, process(std::move(process_))
{
std::cerr << "UserDefinedFunction::UserDefinedFunction " << config.argument_types.size() << " ";
std::cerr << " config format " << config.format << std::endl;
}
String getName() const override { return config.name; }
@ -63,14 +63,6 @@ public:
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
{
std::cerr << "UserDefinedFunction::executeImpl " << input_rows_count << " result type " << result_type->getName();
std::cerr << " arguments " << arguments.size() << std::endl;
for (size_t i = 0; i < arguments.size(); ++i)
{
const auto & argument = arguments[i];
std::cerr << "Index " << i << " structure " << argument.dumpStructure() << std::endl;
}
Block arguments_block(arguments);
ColumnWithTypeAndName result(result_type, "result");
@ -87,13 +79,6 @@ public:
}};
std::vector<ShellCommandSource::SendDataTask> tasks = {std::move(task)};
// auto & buffer = process->out;
// char buffer_data[4096] {};
// size_t read_size = buffer.read(buffer_data, sizeof(buffer_data));
// buffer_data[read_size] = '\0';
// std::cerr << "Buffer data read size " << read_size << " data " << buffer_data << std::endl;
Pipe pipe(std::make_unique<ShellCommandSource>(context, config.format, result_block.cloneEmpty(), std::move(process), nullptr, std::move(tasks)));
QueryPipeline pipeline;
@ -101,21 +86,12 @@ public:
PullingPipelineExecutor executor(pipeline);
// std::cerr << "Executor pull blocks" << std::endl;
auto result_column = result_type->createColumn();
Block block;
while (executor.pull(block))
{
std::cerr << "Executor pull block " << block.rows() << std::endl;
result_column->insertFrom(*block.safeGetByPosition(0).column, block.rows());
}
std::cerr << "Result column size " << result_column->size() << std::endl;
Field value;
for (size_t i = 0; i < result_column->size(); ++i)
{
result_column->get(i, value);
std::cerr << "Index " << i << " value " << value.dump() << std::endl;
const auto & result_column_to_add = *block.safeGetByPosition(0).column;
result_column->insertRangeFrom(result_column_to_add, 0, result_column_to_add.size());
}
size_t result_column_size = result_column->size();
@ -159,8 +135,6 @@ ExternalLoader::LoadablePtr ExternalUserDefinedExecutableFunctionsLoader::create
const std::string & key_in_config,
const std::string &) const
{
std::cerr << "ExternalUserDefinedExecutableFunctionsLoader::create name " << name << " key in config " << key_in_config << std::endl;
String command = config.getString(key_in_config + ".command");
String format = config.getString(key_in_config + ".format");
DataTypePtr result_type = DataTypeFactory::instance().get(config.getString(key_in_config + ".return_type"));
@ -190,9 +164,14 @@ ExternalLoader::LoadablePtr ExternalUserDefinedExecutableFunctionsLoader::create
.result_type = std::move(result_type),
};
auto function = std::make_shared<UserDefinedExecutableFunction>(function_config, lifetime);
std::shared_ptr<scope_guard> function_deregister_ptr = std::make_shared<scope_guard>([function_name = function_config.name]()
{
UserDefinedExecutableFunctionFactory::instance().unregisterFunction(function_name);
});
FunctionFactory::instance().registerFunction(function_config.name, [function](ContextPtr function_context)
auto function = std::make_shared<UserDefinedExecutableFunction>(function_config, std::move(function_deregister_ptr), lifetime);
UserDefinedExecutableFunctionFactory::instance().registerFunction(function_config.name, [function](ContextPtr function_context)
{
auto shell_command = ShellCommand::execute(function->getConfig().script_path);
std::shared_ptr<UserDefinedFunction> user_defined_function = std::make_shared<UserDefinedFunction>(function_context, function->getConfig(), std::move(shell_command));

View File

@ -25,8 +25,6 @@ public:
UserDefinedExecutableFunctionPtr tryGetUserDefinedFunction(const std::string & user_defined_function_name) const;
static void resetAll();
protected:
LoadablePtr create(const std::string & name,
const Poco::Util::AbstractConfiguration & config,

View File

@ -13,8 +13,10 @@ namespace DB
/// Builds a loadable user defined executable function description.
/// `config_` and `lifetime_` are copied into the object; `function_deregister_` is a shared
/// scope guard whose ownership is moved into the object.
UserDefinedExecutableFunction::UserDefinedExecutableFunction(
const Config & config_,
std::shared_ptr<scope_guard> function_deregister_,
const ExternalLoadableLifetime & lifetime_)
: config(config_)
, function_deregister(std::move(function_deregister_))
, lifetime(lifetime_)
{
}

View File

@ -2,6 +2,8 @@
#include <string>
#include <common/scope_guard.h>
#include <DataTypes/IDataType.h>
#include <Interpreters/IExternalLoadable.h>
@ -23,6 +25,7 @@ public:
UserDefinedExecutableFunction(
const Config & config_,
std::shared_ptr<scope_guard> function_deregister_,
const ExternalLoadableLifetime & lifetime_);
const ExternalLoadableLifetime & getLifetime() const override
@ -47,7 +50,7 @@ public:
std::shared_ptr<const IExternalLoadable> clone() const override
{
return std::make_shared<UserDefinedExecutableFunction>(config, lifetime);
return std::make_shared<UserDefinedExecutableFunction>(config, function_deregister, lifetime);
}
const Config & getConfig() const
@ -67,6 +70,7 @@ public:
private:
Config config;
std::shared_ptr<scope_guard> function_deregister;
ExternalLoadableLifetime lifetime;
};

View File

@ -0,0 +1,86 @@
#include "UserDefinedExecutableFunctionFactory.h"
#include <Functions/FunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
namespace DB
{
/// Error codes used by the exceptions thrown in this file.
/// They are only declared here (extern); the definitions are not part of this file.
namespace ErrorCodes
{
extern const int FUNCTION_ALREADY_EXISTS;
extern const int UNKNOWN_FUNCTION;
}
/// Access point to the single, process-wide factory object.
UserDefinedExecutableFunctionFactory & UserDefinedExecutableFunctionFactory::instance()
{
    /// Function-local static: the factory is constructed the first time this is called.
    static UserDefinedExecutableFunctionFactory factory;
    return factory;
}
/// Stores `creator` under `function_name`.
///
/// Throws FUNCTION_ALREADY_EXISTS when the name (or an alias) is already known to
/// FunctionFactory or AggregateFunctionFactory, or when it is already registered here.
void UserDefinedExecutableFunctionFactory::registerFunction(const String & function_name, Creator creator)
{
    const bool clashes_with_function = FunctionFactory::instance().hasNameOrAlias(function_name);
    if (clashes_with_function)
        throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The function '{}' already exists", function_name);

    const bool clashes_with_aggregate = AggregateFunctionFactory::instance().hasNameOrAlias(function_name);
    if (clashes_with_aggregate)
        throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The aggregate function '{}' already exists", function_name);

    /// Only the map of this factory is protected by the mutex.
    std::lock_guard lock(mutex);

    const bool was_new_name = function_name_to_creator.emplace(function_name, std::move(creator)).second;
    if (was_new_name)
        return;

    throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS,
        "The function name '{}' is not unique",
        function_name);
}
/// Removes the creator registered under `function_name`.
/// Throws UNKNOWN_FUNCTION if no function with that name is registered.
void UserDefinedExecutableFunctionFactory::unregisterFunction(const String & function_name)
{
    std::lock_guard lock(mutex);

    auto it = function_name_to_creator.find(function_name);
    if (it == function_name_to_creator.end())
        throw Exception(ErrorCodes::UNKNOWN_FUNCTION,
            "The function name '{}' is not registered",
            function_name);

    /// Actually drop the entry. Without this the function stays callable after being
    /// "unregistered" and registerFunction() rejects the name forever as not unique.
    function_name_to_creator.erase(it);
}
/// Returns the resolver built by the creator registered under `function_name`,
/// invoked with `context`.
/// Throws UNKNOWN_FUNCTION if no function with that name is registered.
FunctionOverloadResolverPtr UserDefinedExecutableFunctionFactory::get(const String & function_name, ContextPtr context) const
{
    /// Declared before the lock scope so that this copy is also destroyed only after
    /// the mutex has been released.
    Creator creator;

    {
        std::lock_guard lock(mutex);

        auto it = function_name_to_creator.find(function_name);
        if (it == function_name_to_creator.end())
            throw Exception(ErrorCodes::UNKNOWN_FUNCTION,
                "The function name '{}' is not registered",
                function_name);

        creator = it->second;
    }

    /// The creator is arbitrary caller-supplied code. Run it without holding the
    /// (non-recursive) mutex so it can neither block other users of the factory
    /// nor deadlock by calling back into it.
    return creator(std::move(context));
}
/// Returns the resolver built by the creator registered under `function_name`,
/// invoked with `context`, or nullptr if no function with that name is registered.
FunctionOverloadResolverPtr UserDefinedExecutableFunctionFactory::tryGet(const String & function_name, ContextPtr context) const
{
    /// Declared before the lock scope so that this copy is also destroyed only after
    /// the mutex has been released.
    Creator creator;

    {
        std::lock_guard lock(mutex);

        auto it = function_name_to_creator.find(function_name);
        if (it == function_name_to_creator.end())
            return nullptr;

        creator = it->second;
    }

    /// The creator is arbitrary caller-supplied code. Run it without holding the
    /// (non-recursive) mutex so it can neither block other users of the factory
    /// nor deadlock by calling back into it.
    return creator(std::move(context));
}
/// Returns a snapshot of the names of all functions currently registered in this factory.
/// The order follows the iteration order of the underlying unordered map.
std::vector<String> UserDefinedExecutableFunctionFactory::getAllRegisteredNames() const
{
    std::vector<std::string> names;

    std::lock_guard lock(mutex);

    /// One allocation up front: the final size is known while the lock is held.
    names.reserve(function_name_to_creator.size());

    for (auto it = function_name_to_creator.begin(); it != function_name_to_creator.end(); ++it)
        names.push_back(it->first);

    return names;
}
}

View File

@ -0,0 +1,38 @@
#pragma once
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <Common/NamePrompter.h>
#include <Interpreters/Context_fwd.h>
#include <Functions/IFunction.h>
namespace DB
{
/// Factory for user defined executable functions.
/// Maps a function name to a creator callback that builds the function's overload resolver.
class UserDefinedExecutableFunctionFactory : public IHints<1, UserDefinedExecutableFunctionFactory>
{
public:
/// Callback that builds the overload resolver of a function for the given context.
using Creator = std::function<FunctionOverloadResolverPtr(ContextPtr)>;
/// Returns the single shared instance of the factory.
static UserDefinedExecutableFunctionFactory & instance();
/// Register creator under function_name.
/// Throws if the name is already taken by an ordinary function, an aggregate function,
/// or a function already registered in this factory.
void registerFunction(const String & function_name, Creator creator);
/// Unregister function with function_name. Throws an exception if the name is not registered.
void unregisterFunction(const String & function_name);
/// Throws an exception if not found. Otherwise returns the result of the registered creator called with context.
FunctionOverloadResolverPtr get(const String & function_name, ContextPtr context) const;
/// Returns nullptr if not found. Otherwise returns the result of the registered creator called with context.
FunctionOverloadResolverPtr tryGet(const String & function_name, ContextPtr context) const;
/// Get all registered function names.
std::vector<String> getAllRegisteredNames() const override;
private:
/// Registered creators by function name; accessed under `mutex`.
std::unordered_map<String, Creator> function_name_to_creator;
/// mutable so that the const lookup methods can lock it.
mutable std::mutex mutex;
};
}

View File

@ -1,39 +0,0 @@
#pragma once
#include <unordered_map>
#include <mutex>
#include <Common/NamePrompter.h>
#include <Parsers/ASTCreateFunctionQuery.h>
namespace DB
{
/// Factory that stores user defined functions as their create queries, keyed by function name.
class UserDefinedFunctionFactory : public IHints<1, UserDefinedFunctionFactory>
{
public:
static UserDefinedFunctionFactory & instance();
/// Register function with function_name. The create_function_query pointer must be an ASTCreateFunctionQuery.
void registerFunction(const String & function_name, ASTPtr create_function_query);
/// Unregister function with function_name.
void unregisterFunction(const String & function_name);
/// Throws an exception if not found. The resulting AST pointer can safely be cast to ASTCreateFunctionQuery.
ASTPtr get(const String & function_name) const;
/// Returns nullptr if not found. The resulting AST pointer can safely be cast to ASTCreateFunctionQuery.
ASTPtr tryGet(const String & function_name) const;
/// Get all registered function names.
std::vector<String> getAllRegisteredNames() const override;
private:
std::unordered_map<String, ASTPtr> function_name_to_create_query;
mutable std::mutex mutex;
};
}

View File

@ -0,0 +1,15 @@
<functions>
<!-- Test fixture: declares the executable user defined function "test_function".
     It takes two UInt64 arguments and returns UInt64. Data is exchanged with the
     command in TabSeparated format; the command runs clickhouse-local and selects x + y. -->
<function>
<name>test_function</name>
<return_type>UInt64</return_type>
<argument>
<type>UInt64</type>
</argument>
<argument>
<type>UInt64</type>
</argument>
<format>TabSeparated</format>
<command>cd /; clickhouse-local --input-format TabSeparated --output-format TabSeparated --structure 'x UInt64, y UInt64' --query "SELECT x + y FROM table"</command>
<lifetime>0</lifetime>
</function>
</functions>

View File

@ -0,0 +1 @@
-- Calls the executable user defined function test_function with two UInt64 arguments (2 and 2).
SELECT test_function(toUInt64(2), toUInt64(2));

View File

@ -134,6 +134,7 @@
<flush_interval_milliseconds>7500</flush_interval_milliseconds>
</query_log>
<dictionaries_config>*_dictionary.xml</dictionaries_config>
<user_defined_executable_functions_config>*_functions.xml</user_defined_executable_functions_config>
<compression incl="clickhouse_compression">
</compression>
<distributed_ddl>