From 997acdc39eb80650f9f755f93972e4d308a347ef Mon Sep 17 00:00:00 2001 From: Amos Bird Date: Fri, 20 Aug 2021 12:11:47 +0800 Subject: [PATCH] getPort function --- programs/server/Server.cpp | 1 + src/Functions/getServerPort.cpp | 136 ++++++++++++++++++ .../registerFunctionsMiscellaneous.cpp | 2 + src/Interpreters/Context.cpp | 16 +++ src/Interpreters/Context.h | 5 + .../02012_get_server_port.reference | 1 + .../0_stateless/02012_get_server_port.sql | 3 + 7 files changed, 164 insertions(+) create mode 100644 src/Functions/getServerPort.cpp create mode 100644 tests/queries/0_stateless/02012_get_server_port.reference create mode 100644 tests/queries/0_stateless/02012_get_server_port.sql diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index c30ef52f46a..5487361dac8 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -359,6 +359,7 @@ void Server::createServer(const std::string & listen_host, const char * port_nam try { func(port); + global_context->registerServerPort(port_name, port); } catch (const Poco::Exception &) { diff --git a/src/Functions/getServerPort.cpp b/src/Functions/getServerPort.cpp new file mode 100644 index 00000000000..8596bcd6a07 --- /dev/null +++ b/src/Functions/getServerPort.cpp @@ -0,0 +1,136 @@ +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int ILLEGAL_COLUMN; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; +} + +namespace +{ + +class ExecutableFunctionGetServerPort : public IExecutableFunction +{ +public: + explicit ExecutableFunctionGetServerPort(UInt16 port_) : port(port_) {} + + String getName() const override { return "getServerPort"; } + + bool useDefaultImplementationForNulls() const override { return false; } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName &, const DataTypePtr &, size_t input_rows_count) const override + { + return DataTypeNumber().createColumnConst(input_rows_count, port); + } + +private: + UInt16 port; +}; + +class FunctionBaseGetServerPort : public IFunctionBase +{ +public: + explicit FunctionBaseGetServerPort(bool is_distributed_, UInt16 port_, DataTypes argument_types_, DataTypePtr return_type_) + : is_distributed(is_distributed_), port(port_), argument_types(std::move(argument_types_)), return_type(std::move(return_type_)) + { + } + + String getName() const override { return "getServerPort"; } + + const DataTypes & getArgumentTypes() const override + { + return argument_types; + } + + const DataTypePtr & getResultType() const override + { + return return_type; + } + + bool isDeterministic() const override { return false; } + bool isDeterministicInScopeOfQuery() const override { return true; } + bool isSuitableForConstantFolding() const override { return !is_distributed; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } + + ExecutableFunctionPtr prepare(const ColumnsWithTypeAndName &) const override + { + return std::make_unique(port); + } + +private: + bool is_distributed; + UInt16 port; + DataTypes argument_types; + DataTypePtr return_type; +}; + +class GetServerPortOverloadResolver : public IFunctionOverloadResolver, WithContext +{ +public: + static constexpr auto name = "getServerPort"; + + String getName() const override { return name; } + + static FunctionOverloadResolverPtr create(ContextPtr context_) + { + return std::make_unique(context_); + } + + explicit GetServerPortOverloadResolver(ContextPtr context_) : WithContext(context_) {} + + size_t getNumberOfArguments() const override { return 1; } + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; } + bool isDeterministic() const override { return false; } + bool isDeterministicInScopeOfQuery() const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & data_types) const override + { + size_t number_of_arguments = data_types.size(); + if (number_of_arguments != 1) + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Number of arguments for function {} doesn't match: passed {}, should be 1", + getName(), + number_of_arguments); + return std::make_shared>(); + } + + FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override + { + if (!isString(arguments[0].type)) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "The argument of function {} should be a constant string with the name of a setting", + getName()); + const auto * column = arguments[0].column.get(); + if (!column || !checkAndGetColumnConstStringOrFixedString(column)) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "The argument of function {} should be a constant string with the name of a setting", + getName()); + + String port_name{column->getDataAt(0)}; + auto port = getContext()->getServerPort(port_name); + + DataTypes argument_types; + argument_types.emplace_back(arguments.back().type); + return std::make_unique(getContext()->isDistributed(), port, argument_types, return_type); + } +}; + +} + +void registerFunctionGetServerPort(FunctionFactory & factory) +{ + factory.registerFunction(); +} + +} diff --git a/src/Functions/registerFunctionsMiscellaneous.cpp b/src/Functions/registerFunctionsMiscellaneous.cpp index 12c54aeeefd..cee2858dd80 100644 --- a/src/Functions/registerFunctionsMiscellaneous.cpp +++ b/src/Functions/registerFunctionsMiscellaneous.cpp @@ -71,6 +71,7 @@ void registerFunctionHasThreadFuzzer(FunctionFactory &); void registerFunctionInitializeAggregation(FunctionFactory &); void registerFunctionErrorCodeToName(FunctionFactory &); void registerFunctionTcpPort(FunctionFactory &); +void registerFunctionGetServerPort(FunctionFactory &); void registerFunctionByteSize(FunctionFactory &); void registerFunctionFile(FunctionFactory & factory); void registerFunctionConnectionId(FunctionFactory & factory); @@ -149,6 +150,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory) registerFunctionInitializeAggregation(factory); registerFunctionErrorCodeToName(factory); registerFunctionTcpPort(factory); + registerFunctionGetServerPort(factory); registerFunctionByteSize(factory); registerFunctionFile(factory); registerFunctionConnectionId(factory); diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index a634c19dcd6..5e5d0f1d1ca 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -227,6 +227,8 @@ struct ContextSharedPart ConfigurationPtr clusters_config; /// Stores updated configs mutable std::mutex clusters_mutex; /// Guards clusters and clusters_config + std::map server_ports; + bool shutdown_called = false; Stopwatch uptime_watch; @@ -1816,6 +1818,20 @@ std::optional Context::getTCPPortSecure() const return {}; } +void Context::registerServerPort(String port_name, UInt16 port) +{ + shared->server_ports.emplace(std::move(port_name), port); +} + +UInt16 Context::getServerPort(const String & port_name) const +{ + auto it = shared->server_ports.find(port_name); + if (it == shared->server_ports.end()) + throw Exception(ErrorCodes::BAD_GET, "There is no port named {}", port_name); + else + return it->second; +} + std::shared_ptr Context::getCluster(const std::string & cluster_name) const { auto res = getClusters()->getCluster(cluster_name); diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 4e378dacf01..e86650f958c 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -584,6 +584,11 @@ public: std::optional getTCPPortSecure() const; + /// Register server ports during server starting up. No lock is held. + void registerServerPort(String port_name, UInt16 port); + + UInt16 getServerPort(const String & port_name) const; + /// For methods below you may need to acquire the context lock by yourself. ContextMutablePtr getQueryContext() const; diff --git a/tests/queries/0_stateless/02012_get_server_port.reference b/tests/queries/0_stateless/02012_get_server_port.reference new file mode 100644 index 00000000000..d58c55a31dc --- /dev/null +++ b/tests/queries/0_stateless/02012_get_server_port.reference @@ -0,0 +1 @@ +9000 diff --git a/tests/queries/0_stateless/02012_get_server_port.sql b/tests/queries/0_stateless/02012_get_server_port.sql new file mode 100644 index 00000000000..cc7fecb0bf0 --- /dev/null +++ b/tests/queries/0_stateless/02012_get_server_port.sql @@ -0,0 +1,3 @@ +select getServerPort('tcp_port'); + +select getServerPort('unknown'); -- { serverError 170 }