From 41416952c83aa83c588b7eed6b9e83ee0b8d01ed Mon Sep 17 00:00:00 2001 From: Gamezardashvili George Date: Tue, 19 Nov 2024 14:58:50 +0000 Subject: [PATCH] SSH Authentication --- base/glibc-compatibility/CMakeLists.txt | 5 + base/glibc-compatibility/musl/libc.c | 10 + base/glibc-compatibility/musl/libc.h | 72 +++ base/glibc-compatibility/musl/openpty.c | 40 ++ base/glibc-compatibility/musl/pty.c | 34 ++ base/glibc-compatibility/musl/pty/pty.h | 15 + .../poco/Util/include/Poco/Util/Application.h | 2 +- base/poco/Util/src/Application.cpp | 18 +- docs/en/interfaces/ssh.md | 9 + programs/client/Client.cpp | 45 +- programs/keeper-client/KeeperClient.cpp | 6 +- programs/local/LocalServer.cpp | 7 +- programs/server/Server.cpp | 44 ++ programs/server/Server.h | 3 + src/Access/AuthenticationData.cpp | 7 +- src/Access/Credentials.h | 1 + src/Access/SSH/SSHPublicKey.cpp | 154 ++++++ src/Access/SSH/SSHPublicKey.h | 67 +++ src/CMakeLists.txt | 7 + src/Client/ClientApplicationBase.h | 2 + src/Client/ClientBase.cpp | 119 +++- src/Client/ClientBase.h | 23 +- src/Client/Connection.cpp | 2 +- src/Client/ConnectionParameters.cpp | 38 +- src/Client/ConnectionParameters.h | 10 +- src/Client/LineReader.cpp | 10 +- src/Client/LocalConnection.cpp | 32 +- src/Client/LocalConnection.h | 15 +- src/Client/Suggest.cpp | 2 +- src/Common/Config/ConfigHelper.cpp | 5 + src/Common/Config/ConfigHelper.h | 2 + src/Common/ErrorCodes.cpp | 5 +- src/Common/LibSSHInitializer.cpp | 61 ++ src/Common/LibSSHInitializer.h | 20 + src/Common/LibSSHLogger.cpp | 62 +++ src/Common/LibSSHLogger.h | 8 + src/Common/ProgressIndication.h | 1 + src/Common/SSHWrapper.h | 5 +- src/Common/clibssh.h | 13 + src/Interpreters/Context.h | 1 + src/Server/ClientEmbedded/ClientEmbedded.cpp | 193 +++++++ src/Server/ClientEmbedded/ClientEmbedded.h | 77 +++ .../ClientEmbedded/ClientEmbeddedRunner.cpp | 64 +++ .../ClientEmbedded/ClientEmbeddedRunner.h | 52 ++ .../ClientEmbedded/IClientDescriptorSet.h | 39 ++ .../ClientEmbedded/PipeClientDescriptorSet.h | 64 +++ .../ClientEmbedded/PtyClientDescriptorSet.cpp | 77 +++ .../ClientEmbedded/PtyClientDescriptorSet.h | 68 +++ src/Server/SSH/SSHBind.cpp | 79 +++ src/Server/SSH/SSHBind.h | 54 ++ src/Server/SSH/SSHChannel.cpp | 87 +++ src/Server/SSH/SSHChannel.h | 50 ++ src/Server/SSH/SSHEvent.cpp | 93 ++++ src/Server/SSH/SSHEvent.h | 48 ++ src/Server/SSH/SSHPtyHandler.cpp | 519 ++++++++++++++++++ src/Server/SSH/SSHPtyHandler.h | 43 ++ src/Server/SSH/SSHPtyHandlerFactory.h | 120 ++++ src/Server/SSH/SSHSession.cpp | 124 +++++ src/Server/SSH/SSHSession.h | 60 ++ src/Server/ServerType.cpp | 4 + src/Server/ServerType.h | 1 + src/Server/TCPServer.h | 2 + tests/integration/test_ssh/__init__.py | 0 tests/integration/test_ssh/configs/server.xml | 6 + tests/integration/test_ssh/configs/users.xml | 12 + tests/integration/test_ssh/keys/lucy_ed25519 | 7 + .../test_ssh/keys/lucy_ed25519.pub | 1 + .../test_ssh/keys/ssh_host_rsa_key | 49 ++ .../test_ssh/keys/ssh_host_rsa_key.pub | 1 + tests/integration/test_ssh/test.py | 45 ++ utils/check-style/check-style | 5 +- 71 files changed, 2946 insertions(+), 80 deletions(-) create mode 100644 base/glibc-compatibility/musl/libc.c create mode 100644 base/glibc-compatibility/musl/libc.h create mode 100644 base/glibc-compatibility/musl/openpty.c create mode 100644 base/glibc-compatibility/musl/pty.c create mode 100644 base/glibc-compatibility/musl/pty/pty.h create mode 100644 docs/en/interfaces/ssh.md create mode 100644 src/Access/SSH/SSHPublicKey.cpp create mode 100644 src/Access/SSH/SSHPublicKey.h create mode 100644 src/Common/LibSSHInitializer.cpp create mode 100644 src/Common/LibSSHInitializer.h create mode 100644 src/Common/LibSSHLogger.cpp create mode 100644 src/Common/LibSSHLogger.h create mode 100644 src/Common/clibssh.h create mode 100644 src/Server/ClientEmbedded/ClientEmbedded.cpp create mode 100644 src/Server/ClientEmbedded/ClientEmbedded.h create mode 100644 src/Server/ClientEmbedded/ClientEmbeddedRunner.cpp create mode 100644 src/Server/ClientEmbedded/ClientEmbeddedRunner.h create mode 100644 src/Server/ClientEmbedded/IClientDescriptorSet.h create mode 100644 src/Server/ClientEmbedded/PipeClientDescriptorSet.h create mode 100644 src/Server/ClientEmbedded/PtyClientDescriptorSet.cpp create mode 100644 src/Server/ClientEmbedded/PtyClientDescriptorSet.h create mode 100644 src/Server/SSH/SSHBind.cpp create mode 100644 src/Server/SSH/SSHBind.h create mode 100644 src/Server/SSH/SSHChannel.cpp create mode 100644 src/Server/SSH/SSHChannel.h create mode 100644 src/Server/SSH/SSHEvent.cpp create mode 100644 src/Server/SSH/SSHEvent.h create mode 100644 src/Server/SSH/SSHPtyHandler.cpp create mode 100644 src/Server/SSH/SSHPtyHandler.h create mode 100644 src/Server/SSH/SSHPtyHandlerFactory.h create mode 100644 src/Server/SSH/SSHSession.cpp create mode 100644 src/Server/SSH/SSHSession.h create mode 100644 tests/integration/test_ssh/__init__.py create mode 100644 tests/integration/test_ssh/configs/server.xml create mode 100644 tests/integration/test_ssh/configs/users.xml create mode 100644 tests/integration/test_ssh/keys/lucy_ed25519 create mode 100644 tests/integration/test_ssh/keys/lucy_ed25519.pub create mode 100644 tests/integration/test_ssh/keys/ssh_host_rsa_key create mode 100644 tests/integration/test_ssh/keys/ssh_host_rsa_key.pub create mode 100644 tests/integration/test_ssh/test.py diff --git a/base/glibc-compatibility/CMakeLists.txt b/base/glibc-compatibility/CMakeLists.txt index a14f66ce22b..a1c42978daa 100644 --- a/base/glibc-compatibility/CMakeLists.txt +++ b/base/glibc-compatibility/CMakeLists.txt @@ -8,6 +8,10 @@ if (GLIBC_COMPATIBILITY) add_headers_and_sources(glibc_compatibility .) add_headers_and_sources(glibc_compatibility musl) + + # There are several symbols which exist only in musl. + set (musl_pty_include_dir musl/pty) + if (ARCH_AARCH64) list (APPEND glibc_compatibility_sources musl/aarch64/syscall.s musl/aarch64/longjmp.s) set (musl_arch_include_dir musl/aarch64) @@ -40,6 +44,7 @@ if (GLIBC_COMPATIBILITY) target_no_warning(glibc-compatibility unused-but-set-variable) target_no_warning(glibc-compatibility builtin-requires-header) + target_include_directories(glibc-compatibility PUBLIC ${musl_pty_include_dir}) target_include_directories(glibc-compatibility PRIVATE libcxxabi ${musl_arch_include_dir}) if (ENABLE_OPENSSL_DYNAMIC) diff --git a/base/glibc-compatibility/musl/libc.c b/base/glibc-compatibility/musl/libc.c new file mode 100644 index 00000000000..2e10942df1b --- /dev/null +++ b/base/glibc-compatibility/musl/libc.c @@ -0,0 +1,10 @@ +#include "libc.h" + +struct __libc __libc; + +size_t __hwcap; +size_t __sysinfo; +char *__progname=0, *__progname_full=0; + +weak_alias(__progname, program_invocation_short_name); +weak_alias(__progname_full, program_invocation_name); diff --git a/base/glibc-compatibility/musl/libc.h b/base/glibc-compatibility/musl/libc.h new file mode 100644 index 00000000000..5e14518312c --- /dev/null +++ b/base/glibc-compatibility/musl/libc.h @@ -0,0 +1,72 @@ +#ifndef LIBC_H +#define LIBC_H + +#include +#include +#include + +struct __locale_map; + +struct __locale_struct { + const struct __locale_map *volatile cat[6]; +}; + +struct tls_module { + struct tls_module *next; + void *image; + size_t len, size, align, offset; +}; + +struct __libc { + int can_do_threads; + int threaded; + int secure; + volatile int threads_minus_1; + size_t *auxv; + struct tls_module *tls_head; + size_t tls_size, tls_align, tls_cnt; + size_t page_size; + struct __locale_struct global_locale; +}; + +#ifndef PAGE_SIZE +#define PAGE_SIZE libc.page_size +#endif + +#ifdef __PIC__ +#define ATTR_LIBC_VISIBILITY __attribute__((visibility("hidden"))) +#else +#define ATTR_LIBC_VISIBILITY +#endif + +extern struct __libc __libc ATTR_LIBC_VISIBILITY; +#define libc __libc + +extern size_t __hwcap ATTR_LIBC_VISIBILITY; +extern size_t __sysinfo ATTR_LIBC_VISIBILITY; +extern char *__progname, *__progname_full; + +/* Designed to avoid any overhead in non-threaded processes */ +void __lock(volatile int *) ATTR_LIBC_VISIBILITY; +void __unlock(volatile int *) ATTR_LIBC_VISIBILITY; +int __lockfile(FILE *) ATTR_LIBC_VISIBILITY; +void __unlockfile(FILE *) ATTR_LIBC_VISIBILITY; +#define LOCK(x) __lock(x) +#define UNLOCK(x) __unlock(x) + +void __synccall(void (*)(void *), void *); +int __setxid(int, int, int, int); + +extern char **__environ; + +#undef weak_alias +#define weak_alias(old, new) \ + extern __typeof(old) new __attribute__((weak, alias(#old))) + +#undef LFS64_2 +#define LFS64_2(x, y) weak_alias(x, y) + +#undef LFS64 +#define LFS64(x) LFS64_2(x, x##64) + +#endif diff --git a/base/glibc-compatibility/musl/openpty.c b/base/glibc-compatibility/musl/openpty.c new file mode 100644 index 00000000000..7b7381f09ee --- /dev/null +++ b/base/glibc-compatibility/musl/openpty.c @@ -0,0 +1,40 @@ +#include +#include +#include +#include "pty.h" +#include +#include + +/* Nonstandard, but vastly superior to the standard functions */ + +int openpty(int *pm, int *ps, char *name, const struct termios *tio, const struct winsize *ws) +{ + int m, s, n=0, cs; + char buf[20]; + + m = open("/dev/ptmx", O_RDWR|O_NOCTTY); + if (m < 0) return -1; + + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs); + + if (ioctl(m, TIOCSPTLCK, &n) || ioctl (m, TIOCGPTN, &n)) + goto fail; + + if (!name) name = buf; + snprintf(name, sizeof buf, "/dev/pts/%d", n); + if ((s = open(name, O_RDWR|O_NOCTTY)) < 0) + goto fail; + + if (tio) tcsetattr(s, TCSANOW, tio); + if (ws) ioctl(s, TIOCSWINSZ, ws); + + *pm = m; + *ps = s; + + pthread_setcancelstate(cs, 0); + return 0; +fail: + close(m); + pthread_setcancelstate(cs, 0); + return -1; +} diff --git a/base/glibc-compatibility/musl/pty.c b/base/glibc-compatibility/musl/pty.c new file mode 100644 index 00000000000..b395d2c09e1 --- /dev/null +++ b/base/glibc-compatibility/musl/pty.c @@ -0,0 +1,34 @@ +#include +#include +#include +#include +#include +#include "libc.h" +#include "syscall.h" + +int posix_openpt(int flags) +{ + return open("/dev/ptmx", flags); +} + +int grantpt(int fd) +{ + return 0; +} + +int unlockpt(int fd) +{ + int unlock = 0; + return ioctl(fd, TIOCSPTLCK, &unlock); +} + +int __ptsname_r(int fd, char *buf, size_t len) +{ + int pty, err; + if (!buf) len = 0; + if ((err = __syscall(SYS_ioctl, fd, TIOCGPTN, &pty))) return -err; + if (snprintf(buf, len, "/dev/pts/%d", pty) >= len) return ERANGE; + return 0; +} + +weak_alias(__ptsname_r, ptsname_r); diff --git a/base/glibc-compatibility/musl/pty/pty.h b/base/glibc-compatibility/musl/pty/pty.h new file mode 100644 index 00000000000..7b45a5ed94f --- /dev/null +++ b/base/glibc-compatibility/musl/pty/pty.h @@ -0,0 +1,15 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include + +int openpty(int *, int *, char *, const struct termios *, const struct winsize *); +int forkpty(int *, char *, const struct termios *, const struct winsize *); + +#ifdef __cplusplus +} +#endif diff --git a/base/poco/Util/include/Poco/Util/Application.h b/base/poco/Util/include/Poco/Util/Application.h index 786e331fe73..6bf22a0de90 100644 --- a/base/poco/Util/include/Poco/Util/Application.h +++ b/base/poco/Util/include/Poco/Util/Application.h @@ -362,7 +362,7 @@ namespace Util void setArgs(int argc, char * argv[]); void setArgs(const ArgVec & args); void getApplicationPath(Poco::Path & path) const; - void processOptions(); + void processPocoOptions(); bool findAppConfigFile(const std::string & appName, const std::string & extension, Poco::Path & path) const; bool findAppConfigFile(const Path & basePath, const std::string & appName, const std::string & extension, Poco::Path & path) const; diff --git a/base/poco/Util/src/Application.cpp b/base/poco/Util/src/Application.cpp index 483315fda60..e4c8a1c160d 100644 --- a/base/poco/Util/src/Application.cpp +++ b/base/poco/Util/src/Application.cpp @@ -90,12 +90,12 @@ Application::~Application() void Application::setup() { poco_assert (_pInstance == 0); - + _pConfig->add(new SystemConfiguration, PRIO_SYSTEM, false, false); _pConfig->add(new MapConfiguration, PRIO_APPLICATION, true, false); - + addSubsystem(new LoggingSubsystem); - + #if defined(POCO_OS_FAMILY_UNIX) && !defined(POCO_VXWORKS) _workingDirAtLaunch = Path::current(); @@ -149,7 +149,7 @@ void Application::init() _pConfig->setString("application.cacheDir", Path::cacheHome() + appPath.getBaseName() + Path::separator()); _pConfig->setString("application.tempDir", Path::tempHome() + appPath.getBaseName() + Path::separator()); _pConfig->setString("application.dataDir", Path::dataHome() + appPath.getBaseName() + Path::separator()); - processOptions(); + processPocoOptions(); } @@ -169,7 +169,7 @@ void Application::initialize(Application& self) _initialized = true; } - + void Application::uninitialize() { if (_initialized) @@ -356,7 +356,7 @@ void Application::setArgs(int argc, char* argv[]) void Application::setArgs(const ArgVec& args) { poco_assert (!args.empty()); - + _command = args[0]; _pConfig->setInt("application.argc", (int) args.size()); _unprocessedArgs = args; @@ -368,7 +368,7 @@ void Application::setArgs(const ArgVec& args) } -void Application::processOptions() +void Application::processPocoOptions() { defineOptions(_options); OptionProcessor processor(_options); @@ -426,7 +426,7 @@ void Application::getApplicationPath(Poco::Path& appPath) const bool Application::findFile(Poco::Path& path) const { if (path.isAbsolute()) return true; - + Path appPath; getApplicationPath(appPath); Path base = appPath.parent(); @@ -472,7 +472,7 @@ bool Application::findAppConfigFile(const std::string& appName, const std::strin bool Application::findAppConfigFile(const Path& basePath, const std::string& appName, const std::string& extension, Path& path) const { poco_assert (!appName.empty()); - + Path p(basePath,appName); p.setExtension(extension); bool found = findFile(p); diff --git a/docs/en/interfaces/ssh.md b/docs/en/interfaces/ssh.md new file mode 100644 index 00000000000..6b3d89e0628 --- /dev/null +++ b/docs/en/interfaces/ssh.md @@ -0,0 +1,9 @@ +--- +slug: /en/interfaces/ssh +sidebar_position: 19 +sidebar_label: SSH Interface +--- + +# SSH Interface + +You can connect to clickhouse-server using any ssh client! diff --git a/programs/client/Client.cpp b/programs/client/Client.cpp index f78071b1278..161e4f86267 100644 --- a/programs/client/Client.cpp +++ b/programs/client/Client.cpp @@ -370,6 +370,8 @@ try showClientVersion(); } + default_database = config().getString("database", ""); + try { connect(); @@ -458,8 +460,11 @@ void Client::connect() { try { + const auto host = ConnectionParameters::Host{hosts_and_ports[attempted_address_index].host}; + const auto database = ConnectionParameters::Database{default_database}; + connection_parameters = ConnectionParameters( - config(), hosts_and_ports[attempted_address_index].host, hosts_and_ports[attempted_address_index].port); + config(), host, database, hosts_and_ports[attempted_address_index].port); if (is_interactive) std::cout << "Connecting to " @@ -508,6 +513,10 @@ void Client::connect() server_version = toString(server_version_major) + "." + toString(server_version_minor) + "." + toString(server_version_patch); load_suggestions = is_interactive && (server_revision >= Suggest::MIN_SERVER_REVISION) && !config().getBool("disable_suggestion", false); wait_for_suggestions_to_load = config().getBool("wait_for_suggestions_to_load", false); + if (load_suggestions) + { + suggestion_limit = config().getInt("suggestion_limit"); + } if (server_display_name = connection->getServerDisplayName(connection_parameters.timeouts); server_display_name.empty()) server_display_name = config().getString("host", "localhost"); @@ -1194,10 +1203,42 @@ void Client::processConfig() echo_queries = config().getBool("echo", false); ignore_error = config().getBool("ignore-error", false); - auto query_id = config().getString("query_id", ""); + query_id = config().getString("query_id", ""); if (!query_id.empty()) global_context->setCurrentQueryId(query_id); } + + if (is_interactive || delayed_interactive) + { + if (home_path.empty()) + { + const char * home_path_cstr = getenv("HOME"); // NOLINT(concurrency-mt-unsafe) + if (home_path_cstr) + home_path = home_path_cstr; + } + + /// Load command history if present. + if (config().has("history_file")) + history_file = config().getString("history_file"); + else + { + auto * history_file_from_env = getenv("CLICKHOUSE_HISTORY_FILE"); // NOLINT(concurrency-mt-unsafe) + if (history_file_from_env) + history_file = history_file_from_env; + else if (!home_path.empty()) + history_file = home_path + "/.clickhouse-client-history"; + } + } + + if (config().has("query")) + { + static_query = config().getRawString("query"); /// Poco configuration should not process substitutions in form of ${...} inside query. + } + + pager = config().getString("pager", ""); + + enable_highlight = config().getBool("highlight", true); + multiline = config().has("multiline"); print_stack_trace = config().getBool("stacktrace", false); pager = config().getString("pager", ""); diff --git a/programs/keeper-client/KeeperClient.cpp b/programs/keeper-client/KeeperClient.cpp index 782ac816150..f8175b0a0a2 100644 --- a/programs/keeper-client/KeeperClient.cpp +++ b/programs/keeper-client/KeeperClient.cpp @@ -2,8 +2,9 @@ #include "Commands.h" #include #include -#include "Common/VersionNumber.h" +#include #include +#include #include #include #include @@ -328,7 +329,8 @@ void KeeperClient::runInteractiveReplxx() query_extenders, query_delimiters, word_break_characters, - /* highlighter_= */ {}); + /* highlighter_= */ {} + ); lr.enableBracketedPaste(); while (true) diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index 3ecc6ecf24d..24089d4190d 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -488,7 +488,11 @@ void LocalServer::setupUsers() void LocalServer::connect() { - connection_parameters = ConnectionParameters(getClientConfiguration(), "localhost"); + connection_parameters = ConnectionParameters( + config(), + ConnectionParameters::Host{"localhost"}, + ConnectionParameters::Database{default_database} + ); /// This is needed for table function input(...). ReadBuffer * in; @@ -502,6 +506,7 @@ void LocalServer::connect() input = std::make_unique(table_file); in = input.get(); } + connection = LocalConnection::createConnection( connection_parameters, client_context, in, need_render_progress, need_render_profile_events, server_display_name); } diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index af383334128..4ca3dc351f0 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -128,6 +128,9 @@ #if USE_SSL # include # include +# include +# include +# include #endif #if USE_GRPC @@ -373,6 +376,16 @@ static std::string getCanonicalPath(std::string && path) return std::move(path); } + +Server::Server() +{ +#if USE_SSL + ::ssh::LibSSHInitializer::instance(); + ::ssh::libsshLogger::initialize(); +#endif +} + + Poco::Net::SocketAddress Server::socketBindListen( const Poco::Util::AbstractConfiguration & config, Poco::Net::ServerSocket & socket, @@ -2813,6 +2826,37 @@ void Server::createServers( }); } + if (server_type.shouldStart(ServerType::Type::TCP_SSH)) + { + port_name = "tcp_ssh_port"; + createServer( + config, + listen_host, + port_name, + listen_try, + start_servers, + servers, + [&](UInt16 port) -> ProtocolServerAdapter + { +#if USE_SSH + Poco::Net::ServerSocket socket; + auto address = socketBindListen(config, socket, listen_host, port, /* secure = */ false); + return ProtocolServerAdapter( + listen_host, + port_name, + "SSH pty: " + address.toString(), + std::make_unique( + new SSHPtyHandlerFactory(*this, config), + server_pool, + socket, + new Poco::Net::TCPServerParams)); +#else + UNUSED(port); + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "SSH protocol is disabled for ClickHouse, as it has been built without libssh"); +#endif + }); + } + if (server_type.shouldStart(ServerType::Type::MYSQL)) { port_name = "mysql_port"; diff --git a/programs/server/Server.h b/programs/server/Server.h index feaf61f1ffd..0a874a02006 100644 --- a/programs/server/Server.h +++ b/programs/server/Server.h @@ -34,6 +34,9 @@ class ProtocolServerAdapter; class Server : public BaseDaemon, public IServer { public: + + Server(); + using ServerApplication::run; Poco::Util::LayeredConfiguration & config() const override diff --git a/src/Access/AuthenticationData.cpp b/src/Access/AuthenticationData.cpp index 37a4e356af8..7778a2ebd41 100644 --- a/src/Access/AuthenticationData.cpp +++ b/src/Access/AuthenticationData.cpp @@ -190,7 +190,12 @@ void AuthenticationData::setPasswordHashHex(const String & hash, bool validate) String AuthenticationData::getPasswordHashHex() const { - if (type == AuthenticationType::LDAP || type == AuthenticationType::KERBEROS || type == AuthenticationType::SSL_CERTIFICATE) + if ( + type == AuthenticationType::LDAP + || type == AuthenticationType::KERBEROS + || type == AuthenticationType::SSL_CERTIFICATE + || type == AuthenticationType::SSH_KEY + ) throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot get password hex hash for authentication type {}", toString(type)); String hex; diff --git a/src/Access/Credentials.h b/src/Access/Credentials.h index b21b7e6921f..ed995369969 100644 --- a/src/Access/Credentials.h +++ b/src/Access/Credentials.h @@ -125,4 +125,5 @@ private: }; #endif + } diff --git a/src/Access/SSH/SSHPublicKey.cpp b/src/Access/SSH/SSHPublicKey.cpp new file mode 100644 index 00000000000..73f438b1780 --- /dev/null +++ b/src/Access/SSH/SSHPublicKey.cpp @@ -0,0 +1,154 @@ +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int SSH_EXCEPTION; + extern const int LOGICAL_ERROR; + extern const int BAD_ARGUMENTS; +} + +} + +namespace ssh +{ + +SSHPublicKey::SSHPublicKey(KeyPtr key_, bool own) : key(key_, own ? &deleter : &disabledDeleter) +{ // disable deleter if class is constructed without ownership + if (!key) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "No ssh_key provided in explicit constructor"); + } +} + +SSHPublicKey::~SSHPublicKey() = default; + +SSHPublicKey::SSHPublicKey(const SSHPublicKey & other) : key(ssh_key_dup(other.key.get()), &deleter) +{ + if (!key) + { + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to duplicate ssh_key"); + } +} + +SSHPublicKey & SSHPublicKey::operator=(const SSHPublicKey & other) +{ + if (this != &other) + { + KeyPtr new_key = ssh_key_dup(other.key.get()); + if (!new_key) + { + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to duplicate ssh_key"); + } + key = UniqueKeyPtr(new_key, deleter); // We don't have access to the pointer from external code, opposed to non owning key object. + // So here we always go for default deleter, regardless of other's + } + return *this; +} + +SSHPublicKey::SSHPublicKey(SSHPublicKey && other) noexcept = default; + +SSHPublicKey & SSHPublicKey::operator=(SSHPublicKey && other) noexcept = default; + +bool SSHPublicKey::operator==(const SSHPublicKey & other) const +{ + return isEqual(other); +} + +bool SSHPublicKey::isEqual(const SSHPublicKey & other) const +{ + int rc = ssh_key_cmp(key.get(), other.key.get(), SSH_KEY_CMP_PUBLIC); + return rc == 0; +} + +SSHPublicKey SSHPublicKey::createFromBase64(const String & base64, const String & key_type) +{ + KeyPtr key; + int rc = ssh_pki_import_pubkey_base64(base64.c_str(), ssh_key_type_from_name(key_type.c_str()), &key); + if (rc != SSH_OK) + { + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed importing public key from base64 format.\n\ + Key: {}\n\ + Type: {}", + base64, key_type + ); + } + return SSHPublicKey(key); +} + +SSHPublicKey SSHPublicKey::createFromFile(const std::string & filename) +{ + KeyPtr key; + int rc = ssh_pki_import_pubkey_file(filename.c_str(), &key); + if (rc != SSH_OK) + { + if (rc == SSH_EOF) + { + throw DB::Exception( + DB::ErrorCodes::BAD_ARGUMENTS, + "Can't import ssh public key from file {} as it doesn't exist or permission denied", filename + ); + } + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Can't import ssh public key from file {}", filename); + } + return SSHPublicKey(key); +} + +SSHPublicKey SSHPublicKey::createNonOwning(KeyPtr key) +{ + return SSHPublicKey(key, false); +} + +namespace +{ + + struct CStringDeleter + { + [[maybe_unused]] void operator()(char * ptr) const { std::free(ptr); } + }; + +} + +String SSHPublicKey::getBase64Representation() const +{ + char * buf = nullptr; + int rc = ssh_pki_export_pubkey_base64(key.get(), &buf); + + if (rc != SSH_OK) + { + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to export public key to base64"); + } + + // Create a String from cstring, which makes a copy of the first one and requires freeing memory after it + std::unique_ptr buf_ptr(buf); // This is to safely manage buf memory + return String(buf_ptr.get()); +} + +String SSHPublicKey::getType() const +{ + const char * type_c = ssh_key_type_to_char(ssh_key_type(key.get())); + if (type_c == nullptr) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Key type is unknown or no key contained"); + } + return String(type_c); +} + +std::size_t SSHPublicKey::KeyHasher::operator()(const SSHPublicKey & input_key) const +{ + String combined_string(input_key.getType()); + combined_string += input_key.getBase64Representation(); + return string_hasher(combined_string); +} + +void SSHPublicKey::deleter(KeyPtr key) +{ + ssh_key_free(key); +} + +} diff --git a/src/Access/SSH/SSHPublicKey.h b/src/Access/SSH/SSHPublicKey.h new file mode 100644 index 00000000000..d67fa0d7954 --- /dev/null +++ b/src/Access/SSH/SSHPublicKey.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include +#include +#include + +struct ssh_key_struct; + +namespace ssh +{ + +class SSHPublicKey +{ +private: + class KeyHasher + { + public: + std::size_t operator()(const SSHPublicKey & input_key) const; + + private: + std::hash string_hasher; + }; + +public: + using KeyPtr = ssh_key_struct *; + using KeySet = std::unordered_set; + + SSHPublicKey() = delete; + ~SSHPublicKey(); + + SSHPublicKey(const SSHPublicKey &); + SSHPublicKey & operator=(const SSHPublicKey &); + + SSHPublicKey(SSHPublicKey &&) noexcept; + SSHPublicKey & operator=(SSHPublicKey &&) noexcept; + + bool operator==(const SSHPublicKey &) const; + + bool isEqual(const SSHPublicKey & other) const; + + String getBase64Representation() const; + + String getType() const; + + static SSHPublicKey createFromBase64(const String & base64, const String & key_type); + + static SSHPublicKey createFromFile(const String & filename); + + // Creates SSHPublicKey, but without owning the memory of ssh_key. + // A user must manage it by himself. (This is implemented for compatibility with libssh callbacks) + static SSHPublicKey createNonOwning(KeyPtr key); + +private: + explicit SSHPublicKey(KeyPtr key, bool own = true); + + static void deleter(KeyPtr key); + + // We may want to not own ssh_key memory, so then we pass this deleter to unique_ptr + static void disabledDeleter(KeyPtr) { } + + using UniqueKeyPtr = std::unique_ptr; + UniqueKeyPtr key; +}; + +} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3627d760d4c..30453881529 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -219,6 +219,9 @@ macro(add_object_library name common_path) endmacro() add_object_library(clickhouse_access Access) +if (TARGET ch_contrib::ssh) + add_object_library(clickhouse_access_ssh Access/SSH) +endif() add_object_library(clickhouse_backups Backups) add_object_library(clickhouse_core Core) add_object_library(clickhouse_core_mysql Core/MySQL) @@ -259,6 +262,10 @@ set_source_files_properties(Client/ClientBaseOptimizedParts.cpp PROPERTIES COMPI add_object_library(clickhouse_bridge BridgeHelper) add_object_library(clickhouse_server Server) add_object_library(clickhouse_server_http Server/HTTP) +if (TARGET ch_contrib::ssh) + add_object_library(clickhouse_server_ssh Server/SSH) +endif() +add_object_library(clickhouse_server_embedded_client Server/ClientEmbedded) add_object_library(clickhouse_formats Formats) add_object_library(clickhouse_processors Processors) add_object_library(clickhouse_processors_executors Processors/Executors) diff --git a/src/Client/ClientApplicationBase.h b/src/Client/ClientApplicationBase.h index 3663271dd25..5f956e0b500 100644 --- a/src/Client/ClientApplicationBase.h +++ b/src/Client/ClientApplicationBase.h @@ -47,6 +47,8 @@ protected: void setupSignalHandler() override; void addMultiquery(std::string_view query, Arguments & common_arguments) const; + virtual void readArguments(int argc, char ** argv, Arguments & common_arguments, std::vector &, std::vector &) = 0; + private: void parseAndCheckOptions(OptionsDescription & options_description, po::variables_map & options, Arguments & arguments); diff --git a/src/Client/ClientBase.cpp b/src/Client/ClientBase.cpp index c0f5744a4d5..9cc9ffb7b98 100644 --- a/src/Client/ClientBase.cpp +++ b/src/Client/ClientBase.cpp @@ -115,11 +115,14 @@ namespace Setting namespace ErrorCodes { extern const int BAD_ARGUMENTS; + extern const int DEADLOCK_AVOIDED; + extern const int DATABASE_ACCESS_DENIED; extern const int CLIENT_OUTPUT_FORMAT_SPECIFIED; extern const int UNKNOWN_PACKET_FROM_SERVER; extern const int NO_DATA_TO_INSERT; extern const int UNEXPECTED_PACKET_FROM_SERVER; + extern const int INCORRECT_FILE_NAME; extern const int INVALID_USAGE_OF_INPUT; extern const int CANNOT_SET_SIGNAL_HANDLER; extern const int LOGICAL_ERROR; @@ -544,6 +547,7 @@ try { if (!output_format) { + auto is_embedded = global_context->getApplicationType() == Context::ApplicationType::SERVER; /// Ignore all results when fuzzing as they can be huge. if (query_fuzzer_runs) { @@ -552,7 +556,7 @@ try } WriteBuffer * out_buf = nullptr; - if (!pager.empty()) + if (!pager.empty() && !is_embedded) { if (SIG_ERR == signal(SIGPIPE, SIG_IGN)) throw ErrnoException(ErrorCodes::CANNOT_SET_SIGNAL_HANDLER, "Cannot set signal handler for SIGPIPE"); @@ -583,8 +587,12 @@ try /// The query can specify output format or output file. if (const auto * query_with_output = dynamic_cast(parsed_query.get())) { + if (query_with_output->out_file && is_embedded) + { + error_stream << "Out files are disabled when you are running client embedded into server. Ignoring this option.\n"; + } String out_file; - if (query_with_output->out_file) + if (query_with_output->out_file && !is_embedded) { select_into_file = true; @@ -628,7 +636,7 @@ try { select_into_file_and_stdout = true; out_file_buf = std::make_unique(std::vector{std::move(out_file_buf), - std::make_shared(STDOUT_FILENO)}); + std::make_shared(out_fd)}); } // We are writing to file, so default format is the same as in non-interactive mode. @@ -642,7 +650,7 @@ try const auto & id = query_with_output->format->as(); current_format = id.name(); } - else if (query_with_output->out_file) + else if (query_with_output->out_file && !is_embedded) { auto format_name = FormatFactory::instance().tryGetFormatFromFileName(out_file); if (format_name) @@ -691,7 +699,7 @@ void ClientBase::initLogsOutputStream() if (server_logs_file.empty()) { /// Use stderr by default - out_logs_buf = std::make_unique>(STDERR_FILENO); + out_logs_buf = std::make_unique>(err_fd); wb = out_logs_buf.get(); color_logs = stderr_is_a_tty; } @@ -852,12 +860,22 @@ void ClientBase::initTTYBuffer(ProgressOption progress_option, ProgressOption pr /// use ProgressOption that was set for the progress bar for progress table as well. ProgressOption progress = progress_option ? progress_option : progress_table_option; - static constexpr auto tty_file_name = "/dev/tty"; - /// Output all progress bar commands to terminal at once to avoid flicker. /// This size is usually greater than the window size. static constexpr size_t buf_size = 1024; + // If we are embedded into server, there is no need to access terminal device via opening a file. + // Actually we need to pass tty's name, if we don't want this condition statement, + // because /dev/tty stands for controlling terminal of the process, thus a client will not see progress line. + // So it's easier to just pass a descriptor, without the terminal name. + if (global_context->getApplicationType() == Context::ApplicationType::SERVER) + { + tty_buf = std::make_unique(out_fd, buf_size); + return; + } + + static constexpr auto tty_file_name = "/dev/tty"; + if (is_interactive || progress == ProgressOption::TTY) { std::error_code ec; @@ -892,7 +910,7 @@ void ClientBase::initTTYBuffer(ProgressOption progress_option, ProgressOption pr if (stderr_is_a_tty || progress == ProgressOption::ERR) { - tty_buf = std::make_unique>(STDERR_FILENO, buf_size); + tty_buf = std::make_unique>(err_fd, buf_size); } else { @@ -1509,7 +1527,7 @@ void ClientBase::resetOutput() if (SIG_ERR == signal(SIGQUIT, SIG_DFL)) throw ErrnoException(ErrorCodes::CANNOT_SET_SIGNAL_HANDLER, "Cannot set signal handler for SIGQUIT"); - setupSignalHandler(); + // setupSignalHandler(); } pager_cmd = nullptr; @@ -1612,6 +1630,27 @@ void ClientBase::processInsertQuery(const String & query_to_execute, ASTPtr pars throw Exception(ErrorCodes::NO_DATA_TO_INSERT, "No data to insert"); return; } + // Validate infile before we pass further, as some files may be unsafe if client is embedded into server + if (global_context->getApplicationType() == Context::ApplicationType::SERVER && parsed_insert_query.infile) + { + const auto & in_file_node = parsed_insert_query.infile->as(); + const auto in_file = in_file_node.value.safeGet(); + String user_files_absolute_path = fs::weakly_canonical(global_context->getUserFilesPath()); + fs::path fs_table_path(in_file); + if (fs_table_path.is_relative()) + fs_table_path = user_files_absolute_path / fs_table_path; + + /// Do not use fs::canonical or fs::weakly_canonical. + /// Otherwise it will not allow to work with symlinks in `user_files_path` directory. + String path = fs::absolute(fs_table_path).lexically_normal(); /// Normalize path. + + auto table_path_stat = fs::status(path); + if (!fs::exists(table_path_stat)) + throw Exception(ErrorCodes::INCORRECT_FILE_NAME, "Provided file doesn't exist: {}", in_file); + + if (!fileOrSymlinkPathStartsWith(path, user_files_absolute_path)) + throw Exception(ErrorCodes::DATABASE_ACCESS_DENIED, "File `{}` is not inside `{}`", path, user_files_absolute_path); + } query_interrupt_handler.start(); SCOPE_EXIT({ query_interrupt_handler.stop(); }); @@ -1641,7 +1680,10 @@ void ClientBase::processInsertQuery(const String & query_to_execute, ASTPtr pars { /// If structure was received (thus, server has not thrown an exception), /// send our data with that structure. - setInsertionTable(parsed_insert_query); + if (global_context->getApplicationType() != Context::ApplicationType::SERVER) + { + setInsertionTable(parsed_insert_query); + } sendData(sample, columns_description, parsed_query); receiveEndOfQuery(); @@ -2122,6 +2164,8 @@ void ClientBase::processParsedSingleQuery(const String & full_query, const Strin { const String & new_database = use_query->getDatabase(); /// If the client initiates the reconnection, it takes the settings from the config. + /// TODO: Revisit + default_database = new_database; getClientConfiguration().setString("database", new_database); /// If the connection initiates the reconnection, it uses its variable. connection->setDefaultDatabase(new_database); @@ -2459,14 +2503,14 @@ bool ClientBase::executeMultiQuery(const String & all_queries_text) if (!server_exception) { error_matches_hint = false; - fmt::print(stderr, "Expected server error code '{}' but got no server error (query: {}).\n", + error_stream << fmt::format("Expected server error code '{}' but got no server error (query: {}).\n", test_hint.serverErrors(), full_query); } else if (!test_hint.hasExpectedServerError(server_exception->code())) { error_matches_hint = false; - fmt::print(stderr, "Expected server error code: {} but got: {} (query: {}).\n", - test_hint.serverErrors(), server_exception->code(), full_query); + error_stream << fmt::format("Expected server error code: {} but got: {} (query: {}).\n", + test_hint.serverErrors(), server_exception->code(), full_query); } } if (test_hint.hasClientErrors()) @@ -2474,13 +2518,13 @@ bool ClientBase::executeMultiQuery(const String & all_queries_text) if (!client_exception) { error_matches_hint = false; - fmt::print(stderr, "Expected client error code '{}' but got no client error (query: {}).\n", + error_stream << fmt::format("Expected client error code '{}' but got no client error (query: {}).\n", test_hint.clientErrors(), full_query); } else if (!test_hint.hasExpectedClientError(client_exception->code())) { error_matches_hint = false; - fmt::print(stderr, "Expected client error code '{}' but got '{}' (query: {}).\n", + error_stream << fmt::format("Expected client error code '{}' but got '{}' (query: {}).\n", test_hint.clientErrors(), client_exception->code(), full_query); } } @@ -2497,14 +2541,14 @@ bool ClientBase::executeMultiQuery(const String & all_queries_text) if (test_hint.hasClientErrors()) { error_matches_hint = false; - fmt::print(stderr, + error_stream << fmt::format( "The query succeeded but the client error '{}' was expected (query: {}).\n", test_hint.clientErrors(), full_query); } if (test_hint.hasServerErrors()) { error_matches_hint = false; - fmt::print(stderr, + error_stream << fmt::format( "The query succeeded but the server error '{}' was expected (query: {}).\n", test_hint.serverErrors(), full_query); } @@ -2737,33 +2781,54 @@ void ClientBase::runInteractive() LineReader::Patterns query_delimiters = {";", "\\G", "\\G;"}; char word_break_characters[] = " \t\v\f\a\b\r\n`~!@#$%^&*()-=+[{]}\\|;:'\",<.>/?"; + std::unique_ptr lr; + + #if USE_REPLXX replxx::Replxx::highlighter_callback_t highlight_callback{}; + if (getClientConfiguration().getBool("highlight", true)) highlight_callback = [this](const String & query, std::vector & colors) { highlight(query, colors, *client_context); }; - ReplxxLineReader lr( + String actual_history_file_path; + if (global_context->getApplicationType() != Context::ApplicationType::SERVER) + actual_history_file_path = history_file; + + lr = std::make_unique( *suggest, - history_file, + actual_history_file_path, history_max_entries, getClientConfiguration().has("multiline"), getClientConfiguration().getBool("ignore_shell_suspend", true), query_extenders, query_delimiters, word_break_characters, - highlight_callback); + highlight_callback, + input_stream, + output_stream, + in_fd, + out_fd, + err_fd + ); #else - (void)word_break_characters; - LineReader lr( + lr = LineReader( history_file, getClientConfiguration().has("multiline"), query_extenders, - query_delimiters); + query_delimiters, + word_break_characters, + input_stream, + output_stream, + in_fd + ); #endif + /// Enable bracketed-paste-mode so that we are able to paste multiline queries as a whole. + lr->enableBracketedPaste(); + static const std::initializer_list> backslash_aliases = { { "\\l", "SHOW DATABASES" }, @@ -2787,10 +2852,10 @@ void ClientBase::runInteractive() /// But keep it disabled outside of query input, because it breaks password input /// (e.g. if we need to reconnect and show a password prompt). /// (Alternatively, we could make the password input ignore the control sequences.) - lr.enableBracketedPaste(); - SCOPE_EXIT({ lr.disableBracketedPaste(); }); + lr->enableBracketedPaste(); + SCOPE_EXIT({ lr->disableBracketedPaste(); }); - input = lr.readLine(prompt(), ":-] "); + input = lr->readLine(prompt(), ":-] "); } if (input.empty()) @@ -2929,7 +2994,7 @@ void ClientBase::runNonInteractive() { /// If 'query' parameter is not set, read a query from stdin. /// The query is read entirely into memory (streaming is disabled). - ReadBufferFromFileDescriptor in(STDIN_FILENO); + ReadBufferFromFileDescriptor in(in_fd); String text; readStringUntilEOF(text, in); if (query_fuzzer_runs) diff --git a/src/Client/ClientBase.h b/src/Client/ClientBase.h index 5fa7006b313..6b1a5172901 100644 --- a/src/Client/ClientBase.h +++ b/src/Client/ClientBase.h @@ -70,6 +70,7 @@ enum ProgressOption ProgressOption toProgressOption(std::string progress); std::istream& operator>> (std::istream & in, ProgressOption & progress); + class InternalTextLogs; class TerminalKeystrokeInterceptor; class WriteBufferFromFileDescriptor; @@ -132,6 +133,8 @@ protected: static void adjustQueryEnd(const char *& this_query_end, const char * all_queries_end, uint32_t max_parser_depth, uint32_t max_parser_backtracks); virtual void setupSignalHandler() = 0; + ASTPtr parseQuery(const char *& pos, const char * end, bool allow_multi_statements) const; + bool executeMultiQuery(const String & all_queries_text); MultiQueryProcessingStage analyzeMultiQueryText( const char *& this_query_begin, const char *& this_query_end, const char * all_queries_end, @@ -163,13 +166,6 @@ protected: /// Returns true if query processing was successful. bool processQueryText(const String & text); - virtual void readArguments( - int argc, - char ** argv, - Arguments & common_arguments, - std::vector & external_tables_arguments, - std::vector & hosts_and_ports_arguments) = 0; - void setInsertionTable(const ASTInsertQuery & insert_query); private: @@ -223,6 +219,7 @@ protected: void start(Int32 signals_before_stop = 1) { exit_after_signals.store(signals_before_stop); } /// Set value not greater then 0 to mark the query as stopped. + void stop() { exit_after_signals.store(0); } /// Return true if the query was stopped. @@ -257,12 +254,19 @@ protected: /// Should be one of the first, to be destroyed the last, /// since other members can use them. - SharedContextHolder shared_context; + SharedContextHolder shared_context; // maybe not initialized ContextMutablePtr global_context; /// Client context is a context used only by the client to parse queries, process query parameters and to connect to clickhouse-server. ContextMutablePtr client_context; + String default_database; + String query_id; + Int32 suggestion_limit; + bool enable_highlight = true; + bool multiline = false; + String static_query; + std::unique_ptr keystroke_interceptor; bool is_interactive = false; /// Use either interactive line editing interface or batch mode. @@ -307,7 +311,7 @@ protected: MergeTreeSettings cmd_merge_tree_settings; /// thread status should be destructed before shared context because it relies on process list. - std::optional thread_status; + std::optional thread_status; // may be not initialized in embedded client ServerConnectionPtr connection; ConnectionParameters connection_parameters; @@ -328,6 +332,7 @@ protected: std::unique_ptr logs_out_stream; /// /dev/tty if accessible or std::cerr - for progress bar. + /// But running embedded into server, we write the progress to given tty file dexcriptor. /// We prefer to output progress bar directly to tty to allow user to redirect stdout and stderr and still get the progress indication. std::unique_ptr tty_buf; std::mutex tty_mutex; diff --git a/src/Client/Connection.cpp b/src/Client/Connection.cpp index ace3c2fe9af..5b2c943d348 100644 --- a/src/Client/Connection.cpp +++ b/src/Client/Connection.cpp @@ -11,7 +11,7 @@ #include #include #include -#include +#include #include #include #include "Common/logger_useful.h" diff --git a/src/Client/ConnectionParameters.cpp b/src/Client/ConnectionParameters.cpp index 4d0a9ffa08c..9be657b8fb6 100644 --- a/src/Client/ConnectionParameters.cpp +++ b/src/Client/ConnectionParameters.cpp @@ -12,7 +12,6 @@ #include - namespace DB { @@ -39,15 +38,32 @@ bool enableSecureConnection(const Poco::Util::AbstractConfiguration & config, co } -ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfiguration & config, - std::string connection_host, - std::optional connection_port) - : host(connection_host) - , port(connection_port.value_or(getPortFromConfig(config, connection_host))) +ConnectionParameters ConnectionParameters::createForEmbedded(const String & user, const String & database) { - security = enableSecureConnection(config, connection_host) ? Protocol::Secure::Enable : Protocol::Secure::Disable; + auto connection_params = ConnectionParameters(); + connection_params.host = "localhost"; + connection_params.security = Protocol::Secure::Disable; + connection_params.password = ""; + connection_params.user = user; + connection_params.default_database = database; + connection_params.compression = Protocol::Compression::Disable; - default_database = config.getString("database", ""); + // TODO: Pass settings struct. + // connection_params.timeouts = ConnectionTimeouts::getTCPTimeoutsWithFailover(getGlobal); + + connection_params.timeouts.sync_request_timeout = Poco::Timespan(DBMS_DEFAULT_SYNC_REQUEST_TIMEOUT_SEC, 0); + return connection_params; +} + +ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfiguration & config, + const Host & host_, + const Database & database, + std::optional port_) + : host(host_) + , port(port_.value_or(getPortFromConfig(config, host_))) + , default_database(database) +{ + security = enableSecureConnection(config, host_) ? Protocol::Secure::Enable : Protocol::Secure::Disable; /// changed the default value to "default" to fix the issue when the user in the prompt is blank user = config.getString("user", "default"); @@ -139,10 +155,10 @@ ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfigurati Poco::Timespan(config.getInt("sync_request_timeout", DBMS_DEFAULT_SYNC_REQUEST_TIMEOUT_SEC), 0)); } -ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfiguration & config, - std::string connection_host) - : ConnectionParameters(config, config.getString("host", "localhost"), getPortFromConfig(config, connection_host)) +ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfiguration & config_, const Host & host_, const Database & database_) + : ConnectionParameters(config_, host_, database_, getPortFromConfig(config_, host_)) { + } UInt16 ConnectionParameters::getPortFromConfig(const Poco::Util::AbstractConfiguration & config, diff --git a/src/Client/ConnectionParameters.h b/src/Client/ConnectionParameters.h index 85e5fcb0ce1..6ad85288586 100644 --- a/src/Client/ConnectionParameters.h +++ b/src/Client/ConnectionParameters.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -29,15 +30,20 @@ struct ConnectionParameters Protocol::Compression compression = Protocol::Compression::Enable; ConnectionTimeouts timeouts; + using Database = StrongTypedef; + using Host = StrongTypedef; + ConnectionParameters() = default; - ConnectionParameters(const Poco::Util::AbstractConfiguration & config, String host); - ConnectionParameters(const Poco::Util::AbstractConfiguration & config, String host, std::optional port); + ConnectionParameters(const Poco::Util::AbstractConfiguration & config_, const Host & host_, const Database & database_); + ConnectionParameters(const Poco::Util::AbstractConfiguration & config_, const Host & host_, const Database & database_, std::optional port_); static UInt16 getPortFromConfig(const Poco::Util::AbstractConfiguration & config, const std::string & connection_host); /// Ask to enter the user's password if password option contains this value. /// "\n" is used because there is hardly a chance that a user would use '\n' as password. static constexpr std::string_view ASK_PASSWORD = "\n"; + + static ConnectionParameters createForEmbedded(const String & user, const String & database); }; } diff --git a/src/Client/LineReader.cpp b/src/Client/LineReader.cpp index e077343ada3..a10b51a1370 100644 --- a/src/Client/LineReader.cpp +++ b/src/Client/LineReader.cpp @@ -59,8 +59,11 @@ namespace DB /// Allows delaying the start of query execution until the entirety of query is inserted. bool LineReader::hasInputData() const { - pollfd fd{in_fd, POLLIN, 0}; - return poll(&fd, 1, 0) == 1; + timeval timeout = {0, 0}; + fd_set fds{}; + FD_ZERO(&fds); + FD_SET(in_fd, &fds); + return select(1, &fds, nullptr, nullptr, &timeout) == 1; } replxx::Replxx::completions_t LineReader::Suggest::getCompletions(const String & prefix, size_t prefix_length, const char * word_break_characters) @@ -131,7 +134,8 @@ void LineReader::Suggest::addWords(Words && new_words) // NOLINT(cppcoreguidelin } } -LineReader::LineReader( +LineReader::LineReader +( const String & history_file_path_, bool multiline_, Patterns extenders_, diff --git a/src/Client/LocalConnection.cpp b/src/Client/LocalConnection.cpp index bb36d0bbf39..a72a60eac16 100644 --- a/src/Client/LocalConnection.cpp +++ b/src/Client/LocalConnection.cpp @@ -1,6 +1,7 @@ #include "LocalConnection.h" #include #include +#include #include #include #include @@ -46,15 +47,25 @@ namespace ErrorCodes LocalConnection::LocalConnection(ContextPtr context_, ReadBuffer * in_, bool send_progress_, bool send_profile_events_, const String & server_display_name_) : WithContext(context_) - , session(getContext(), ClientInfo::Interface::LOCAL) + , session(std::make_unique(getContext(), ClientInfo::Interface::LOCAL)) , send_progress(send_progress_) , send_profile_events(send_profile_events_) , server_display_name(server_display_name_) , in(in_) { /// Authenticate and create a context to execute queries. - session.authenticate("default", "", Poco::Net::SocketAddress{}); - session.makeSessionContext(); + session->authenticate("default", "", Poco::Net::SocketAddress{}); + session->makeSessionContext(); +} + +LocalConnection::LocalConnection( + std::unique_ptr && session_, bool send_progress_, bool send_profile_events_, const String & server_display_name_) + : WithContext(session_->sessionContext()) + , session(std::move(session_)) + , send_progress(send_progress_) + , send_profile_events(send_profile_events_) + , server_display_name(server_display_name_) +{ } LocalConnection::~LocalConnection() @@ -115,9 +126,9 @@ void LocalConnection::sendQuery( /// Suggestion comes without client_info. if (client_info) - query_context = session.makeQueryContext(*client_info); + query_context = session->makeQueryContext(*client_info); else - query_context = session.makeQueryContext(); + query_context = session->makeQueryContext(); query_context->setCurrentQueryId(query_id); if (send_progress) @@ -169,6 +180,7 @@ void LocalConnection::sendQuery( const auto & settings = context->getSettingsRef(); const char * begin = state->query.data(); + const char * end = begin + state->query.size(); const Dialect & dialect = settings[Setting::dialect]; @@ -675,5 +687,15 @@ ServerConnectionPtr LocalConnection::createConnection( return std::make_unique(current_context, in, send_progress, send_profile_events, server_display_name); } +ServerConnectionPtr LocalConnection::createConnection( + const ConnectionParameters &, + std::unique_ptr && session, + bool send_progress, + bool send_profile_events, + const String & server_display_name) +{ + return std::make_unique(std::move(session), send_progress, send_profile_events, server_display_name); +} + } diff --git a/src/Client/LocalConnection.h b/src/Client/LocalConnection.h index c605b37b075..1876b1326d1 100644 --- a/src/Client/LocalConnection.h +++ b/src/Client/LocalConnection.h @@ -76,6 +76,12 @@ public: bool send_profile_events_, const String & server_display_name_); + explicit LocalConnection( + std::unique_ptr && session_, + bool send_progress_ = false, + bool send_profile_events_ = false, + const String & server_display_name_ = ""); + ~LocalConnection() override; IServerConnection::Type getConnectionType() const override { return IServerConnection::Type::LOCAL; } @@ -88,6 +94,13 @@ public: bool send_profile_events = false, const String & server_display_name = ""); + static ServerConnectionPtr createConnection( + const ConnectionParameters & connection_parameters, + std::unique_ptr && session, + bool send_progress = false, + bool send_profile_events = false, + const String & server_display_name = ""); + void setDefaultDatabase(const String & database) override; void getServerVersion(const ConnectionTimeouts & timeouts, @@ -160,7 +173,7 @@ private: bool needSendProgressOrMetrics(); ContextMutablePtr query_context; - Session session; + std::unique_ptr session; bool send_progress; bool send_profile_events; diff --git a/src/Client/Suggest.cpp b/src/Client/Suggest.cpp index e8f5409a009..1e2f89fa135 100644 --- a/src/Client/Suggest.cpp +++ b/src/Client/Suggest.cpp @@ -92,7 +92,7 @@ static String getLoadSuggestionQuery(Int32 suggestion_limit, bool basic_suggesti template void Suggest::load(ContextPtr context, const ConnectionParameters & connection_parameters, Int32 suggestion_limit, bool wait_for_load) { - loading_thread = std::thread([my_context = Context::createCopy(context), connection_parameters, suggestion_limit, this] + loading_thread = std::thread([my_context=Context::createCopy(context), connection_parameters, suggestion_limit, this] { ThreadStatus thread_status; for (size_t retry = 0; retry < 10; ++retry) diff --git a/src/Common/Config/ConfigHelper.cpp b/src/Common/Config/ConfigHelper.cpp index 6de63fe78d7..99d3b27d360 100644 --- a/src/Common/Config/ConfigHelper.cpp +++ b/src/Common/Config/ConfigHelper.cpp @@ -36,6 +36,11 @@ Poco::AutoPtr clone(const Poco::Util::Abstrac return res; } +Poco::AutoPtr createEmpty() +{ + return new Poco::Util::XMLConfiguration(); +} + bool getBool(const Poco::Util::AbstractConfiguration & config, const std::string & key, bool default_, bool empty_as) { if (!config.has(key)) diff --git a/src/Common/Config/ConfigHelper.h b/src/Common/Config/ConfigHelper.h index 513438bd859..f52ba6a3937 100644 --- a/src/Common/Config/ConfigHelper.h +++ b/src/Common/Config/ConfigHelper.h @@ -18,6 +18,8 @@ namespace DB::ConfigHelper /// (i.e. items like "value"). Poco::AutoPtr clone(const Poco::Util::AbstractConfiguration & src); +Poco::AutoPtr createEmpty(); + /// The behavior is like `config.getBool(key, default_)`, /// except when the tag is empty (aka. self-closing), `empty_as` will be used instead of throwing Poco::Exception. bool getBool(const Poco::Util::AbstractConfiguration & config, const std::string & key, bool default_ = false, bool empty_as = true); diff --git a/src/Common/ErrorCodes.cpp b/src/Common/ErrorCodes.cpp index 376ccf6f297..be9d6875b68 100644 --- a/src/Common/ErrorCodes.cpp +++ b/src/Common/ErrorCodes.cpp @@ -622,7 +622,8 @@ M(1000, POCO_EXCEPTION) \ M(1001, STD_EXCEPTION) \ M(1002, UNKNOWN_EXCEPTION) \ - /* See END */ + M(1003, SSH_EXCEPTION) \ +/* See END */ #ifdef APPLY_FOR_EXTERNAL_ERROR_CODES #define APPLY_FOR_ERROR_CODES(M) APPLY_FOR_BUILTIN_ERROR_CODES(M) APPLY_FOR_EXTERNAL_ERROR_CODES(M) @@ -638,7 +639,7 @@ namespace ErrorCodes APPLY_FOR_ERROR_CODES(M) #undef M - constexpr ErrorCode END = 1002; + constexpr ErrorCode END = 1003; ErrorPairHolder values[END + 1]{}; struct ErrorCodesNames diff --git a/src/Common/LibSSHInitializer.cpp b/src/Common/LibSSHInitializer.cpp new file mode 100644 index 00000000000..8346046d6d0 --- /dev/null +++ b/src/Common/LibSSHInitializer.cpp @@ -0,0 +1,61 @@ +#include "config.h" + +#include +#include + +#if USE_SSH + +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int SSH_EXCEPTION; +} + +} + +namespace ssh +{ + +LibSSHInitializer & LibSSHInitializer::instance() +{ + static LibSSHInitializer instance; + return instance; +} + +LibSSHInitializer::LibSSHInitializer() +{ + int rc = ssh_init(); + if (rc != SSH_OK) + { + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to initialize libssh"); + } +} + +LibSSHInitializer::~LibSSHInitializer() +{ + ssh_finalize(); +} + +} + +#else + +namespace ssh +{ + +LibSSHInitializer & LibSSHInitializer::instance() +{ + static LibSSHInitializer instance; + return instance; +} + +LibSSHInitializer::LibSSHInitializer() {} +LibSSHInitializer::~LibSSHInitializer() {} + +} + +#endif diff --git a/src/Common/LibSSHInitializer.h b/src/Common/LibSSHInitializer.h new file mode 100644 index 00000000000..00f682fe93d --- /dev/null +++ b/src/Common/LibSSHInitializer.h @@ -0,0 +1,20 @@ +#pragma once + +namespace ssh +{ + +class LibSSHInitializer +{ +public: + LibSSHInitializer(const LibSSHInitializer &) = delete; + LibSSHInitializer & operator=(const LibSSHInitializer &) = delete; + + static LibSSHInitializer & instance(); + + ~LibSSHInitializer(); + +private: + LibSSHInitializer(); // NOLINT +}; + +} diff --git a/src/Common/LibSSHLogger.cpp b/src/Common/LibSSHLogger.cpp new file mode 100644 index 00000000000..34e8c1fc0ce --- /dev/null +++ b/src/Common/LibSSHLogger.cpp @@ -0,0 +1,62 @@ +#include "config.h" + +#if USE_SSH +# include +# include +# include + +namespace ssh +{ + +namespace +{ + +void libssh_logger_callback(int priority, const char *, const char * buffer, void *) +{ + Poco::Logger * logger = &Poco::Logger::get("LibSSH"); + + switch (priority) + { + case SSH_LOG_NOLOG: + break; + case SSH_LOG_WARNING: + LOG_WARNING(logger, "{}", buffer); + break; + case SSH_LOG_PROTOCOL: + case SSH_LOG_PACKET: + case SSH_LOG_FUNCTIONS: + LOG_TRACE(logger, "{}", buffer); + break; + } +} + +} + +namespace libsshLogger +{ + +void initialize() +{ + ssh_set_log_callback(libssh_logger_callback); + ssh_set_log_level(SSH_LOG_FUNCTIONS); // Set the maximum log level +} + +} + +} + +#else + +namespace ssh +{ + +namespace libsshLogger +{ + +void initialize() {} + +} + +} + +#endif diff --git a/src/Common/LibSSHLogger.h b/src/Common/LibSSHLogger.h new file mode 100644 index 00000000000..1321a798660 --- /dev/null +++ b/src/Common/LibSSHLogger.h @@ -0,0 +1,8 @@ +#pragma once + +namespace ssh::libsshLogger +{ + +void initialize(); + +} diff --git a/src/Common/ProgressIndication.h b/src/Common/ProgressIndication.h index 6beadff1fc4..13010f2351e 100644 --- a/src/Common/ProgressIndication.h +++ b/src/Common/ProgressIndication.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include diff --git a/src/Common/SSHWrapper.h b/src/Common/SSHWrapper.h index b6f0c577edc..b3e63dae563 100644 --- a/src/Common/SSHWrapper.h +++ b/src/Common/SSHWrapper.h @@ -37,9 +37,12 @@ public: String getBase64() const; String getKeyType() const; + friend class SSHKeyFactory; -private: +// private: explicit SSHKey(ssh_key key_) : key(key_) { } + +private: ssh_key key = nullptr; }; diff --git a/src/Common/clibssh.h b/src/Common/clibssh.h new file mode 100644 index 00000000000..226f3bcb099 --- /dev/null +++ b/src/Common/clibssh.h @@ -0,0 +1,13 @@ +#pragma once +/* +Include this file to access libssh api. +*/ + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wreserved-macro-identifier" +#pragma GCC diagnostic ignored "-Wreserved-identifier" +#pragma GCC diagnostic ignored "-Wdocumentation" +#include // IWYU pragma: export +#include // IWYU pragma: export +#include "libssh/callbacks.h" // IWYU pragma: export +#pragma GCC diagnostic pop diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 327ac0af5fd..b777e39fb08 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -1254,6 +1254,7 @@ public: { SERVER, /// The program is run as clickhouse-server daemon (default behavior) CLIENT, /// clickhouse-client + EMBEDDED_CLIENT,/// clickhouse-client being run over SSH tunnel LOCAL, /// clickhouse-local KEEPER, /// clickhouse-keeper (also daemon) DISKS, /// clickhouse-disks diff --git a/src/Server/ClientEmbedded/ClientEmbedded.cpp b/src/Server/ClientEmbedded/ClientEmbedded.cpp new file mode 100644 index 00000000000..e347e1ffcfa --- /dev/null +++ b/src/Server/ClientEmbedded/ClientEmbedded.cpp @@ -0,0 +1,193 @@ +#include + +#include +#include +#include +#include "Common/setThreadName.h" +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + + +namespace Setting +{ + extern const SettingsUInt64 max_insert_block_size; +} +namespace +{ + +template +T getEnvOption(const NameToNameMap & envVars, const String & key, T defaultValue) +{ + auto it = envVars.find(key); + return it == envVars.end() ? defaultValue : parse(it->second); +} + +} + + +void ClientEmbedded::processError(const String &) const +{ + if (ignore_error) + return; + + if (is_interactive) + { + String message; + if (server_exception) + { + message = getExceptionMessage(*server_exception, print_stack_trace, true); + } + else if (client_exception) + { + message = client_exception->message(); + } + + error_stream << fmt::format("Received exception\n{}\n\n", message); + } + else + { + if (server_exception) + server_exception->rethrow(); + if (client_exception) + client_exception->rethrow(); + } +} + + +void ClientEmbedded::cleanup() +{ + try + { + connection.reset(); + } + catch (...) + { + tryLogCurrentException(__PRETTY_FUNCTION__); + } +} + + +void ClientEmbedded::connect() +{ + if (!session) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Error creating connection without session object"); + } + connection_parameters = ConnectionParameters::createForEmbedded(session->sessionContext()->getUserName(), default_database); + connection = LocalConnection::createConnection( + connection_parameters, std::move(session), need_render_progress, need_render_profile_events, server_display_name); + if (!default_database.empty()) + { + connection->setDefaultDatabase(default_database); + } +} + +Poco::Util::LayeredConfiguration & ClientEmbedded::getClientConfiguration() +{ + chassert(layered_configuration); + return *layered_configuration; +} + + +int ClientEmbedded::run(const NameToNameMap & envVars, const String & first_query) +{ +try +{ + client_context = session->sessionContext(); + initClientContext(); + + setThreadName("LocalServerPty"); + + query_processing_stage = QueryProcessingStage::Enum::Complete; + + print_stack_trace = getEnvOption(envVars, "stacktrace", false); + + output_stream << std::fixed << std::setprecision(3); + error_stream << std::fixed << std::setprecision(3); + + is_interactive = stdin_is_a_tty; + static_query = first_query.empty() ? getEnvOption(envVars, "query", "") : first_query; + delayed_interactive = is_interactive && !static_query.empty(); + if (!is_interactive || delayed_interactive) + { + echo_queries = getEnvOption(envVars, "echo", false) || getEnvOption(envVars, "verbose", false); + ignore_error = getEnvOption(envVars, "ignore_error", false); + } + load_suggestions = (is_interactive || delayed_interactive) && !getEnvOption(envVars, "disable_suggestion", false); + if (load_suggestions) + { + suggestion_limit = getEnvOption(envVars, "suggestion_limit", 10000); + } + + + enable_highlight = getEnvOption(envVars, "highlight", true); + multiline = getEnvOption(envVars, "multiline", false); + + default_database = getEnvOption(envVars, "database", ""); + + default_output_format = getEnvOption(envVars, "output-format", getEnvOption(envVars, "format", is_interactive ? "PrettyCompact" : "TSV")); + // TODO: Fix + // insert_format = "Values"; + insert_format_max_block_size = getEnvOption(envVars, "insert_format_max_block_size", + global_context->getSettingsRef()[Setting::max_insert_block_size]); + + + server_display_name = getEnvOption(envVars, "display_name", getFQDNOrHostName()); + prompt_by_server_display_name = getEnvOption(envVars, "prompt_by_server_display_name", "{display_name} :) "); + std::map prompt_substitutions{{"display_name", server_display_name}}; + for (const auto & [key, value] : prompt_substitutions) + boost::replace_all(prompt_by_server_display_name, "{" + key + "}", value); + initTTYBuffer(toProgressOption(getEnvOption(envVars, "progress", "default")), + toProgressOption(getEnvOption(envVars, "progress-table", "default"))); + + if (is_interactive) + { + clearTerminal(); + showClientVersion(); + error_stream << std::endl; + } + + connect(); + + + if (is_interactive && !delayed_interactive) + { + runInteractive(); + } + else + { + runNonInteractive(); + + if (delayed_interactive) + runInteractive(); + } + + cleanup(); + return 0; +} +catch (const DB::Exception & e) +{ + cleanup(); + + error_stream << getExceptionMessage(e, print_stack_trace, true) << std::endl; + return e.code() ? e.code() : -1; +} +catch (...) +{ + cleanup(); + + error_stream << getCurrentExceptionMessage(false) << std::endl; + return getCurrentExceptionCode(); +} +} + + +} diff --git a/src/Server/ClientEmbedded/ClientEmbedded.h b/src/Server/ClientEmbedded/ClientEmbedded.h new file mode 100644 index 00000000000..92c76cf0ce9 --- /dev/null +++ b/src/Server/ClientEmbedded/ClientEmbedded.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + + +#include +#include +#include +#include + + +namespace DB +{ + +// Client class which can be run embedded into server +class ClientEmbedded : public ClientBase +{ +public: + explicit ClientEmbedded( + std::unique_ptr && session_, + int in_fd_, + int out_fd_, + int err_fd_, + std::istream & input_stream_, + std::ostream & output_stream_, + std::ostream & error_stream_) + : ClientBase(in_fd_, out_fd_, err_fd_, input_stream_, output_stream_, error_stream_), session(std::move(session_)) + { + global_context = session->makeSessionContext(); + configuration = ConfigHelper::createEmpty(); + layered_configuration = new Poco::Util::LayeredConfiguration(); + layered_configuration->add(configuration); + } + + int run(const NameToNameMap & envVars, const String & first_query); + + /// NOP + void setupSignalHandler() override {} + + ~ClientEmbedded() override { cleanup(); } + +protected: + void connect() override; + + Poco::Util::LayeredConfiguration & getClientConfiguration() override; + + void processError(const String & query) const override; + + String getName() const override { return "embedded"; } + + void printHelpMessage(const OptionsDescription &, bool) override {} + void addOptions(OptionsDescription &) override {} + void processOptions(const OptionsDescription &, + const CommandLineOptions &, + const std::vector &, + const std::vector &) override {} + void processConfig() override {} + +private: + void cleanup(); + + std::unique_ptr session; + + ConfigurationPtr configuration; + Poco::AutoPtr layered_configuration; +}; + +} diff --git a/src/Server/ClientEmbedded/ClientEmbeddedRunner.cpp b/src/Server/ClientEmbedded/ClientEmbeddedRunner.cpp new file mode 100644 index 00000000000..8a45d6e5203 --- /dev/null +++ b/src/Server/ClientEmbedded/ClientEmbeddedRunner.cpp @@ -0,0 +1,64 @@ +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +void ClientEmbeddedRunner::run(const NameToNameMap & envs, const String & starting_query) +{ + if (started.test_and_set()) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Client has been already started"); + } + LOG_DEBUG(log, "Starting client"); + client_thread = ThreadFromGlobalPool(&ClientEmbeddedRunner::clientRoutine, this, envs, starting_query); +} + + +void ClientEmbeddedRunner::changeWindowSize(int width, int height, int width_pixels, int height_pixels) +{ + auto * pty_descriptors = dynamic_cast(client_descriptors.get()); + if (pty_descriptors == nullptr) + { + throw Exception(ErrorCodes::LOGICAL_ERROR, "Accessing window change on non pty descriptors"); + } + pty_descriptors->changeWindowSize(width, height, width_pixels, height_pixels); +} + +ClientEmbeddedRunner::~ClientEmbeddedRunner() +{ + LOG_DEBUG(log, "Closing server descriptors and waiting for client to finish"); + client_descriptors->closeServerDescriptors(); // May throw if something bad happens to descriptors, which will call std::terminate + if (client_thread.joinable()) + { + client_thread.join(); + } + LOG_DEBUG(log, "Client has finished"); +} + +void ClientEmbeddedRunner::clientRoutine(NameToNameMap envs, String starting_query) +{ + try + { + auto descr = client_descriptors->getDescriptorsForClient(); + auto stre = client_descriptors->getStreamsForClient(); + ClientEmbedded client(std::move(db_session), descr.in, descr.out, descr.err, stre.in, stre.out, stre.err); + client.run(envs, starting_query); + } + catch (...) + { + tryLogCurrentException(__PRETTY_FUNCTION__); + } + finished.test_and_set(); + char c = 0; + // Server may poll on a descriptor waiting for client output, wake him up with invisible character + write(client_descriptors->getDescriptorsForClient().out, &c, 1); +} + +} diff --git a/src/Server/ClientEmbedded/ClientEmbeddedRunner.h b/src/Server/ClientEmbedded/ClientEmbeddedRunner.h new file mode 100644 index 00000000000..febf66b099b --- /dev/null +++ b/src/Server/ClientEmbedded/ClientEmbeddedRunner.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +// Runs embedded client in dedicated thread, passes descriptors, checks its state +class ClientEmbeddedRunner +{ +public: + bool hasStarted() { return started.test(); } + + bool hasFinished() { return finished.test(); } + + // void stopQuery() { client.stopQuery(); } // this is save for client until he uses thread-safe structures to handle query stopping + + void run(const NameToNameMap & envs, const String & starting_query = ""); + + IClientDescriptorSet::DescriptorSet getDescriptorsForServer() { return client_descriptors->getDescriptorsForServer(); } + + bool hasPty() const { return client_descriptors->isPty(); } + + // Sets new window size for tty. Works only if IClientDescriptorSet is pty + void changeWindowSize(int width, int height, int width_pixels, int height_pixels); + + ~ClientEmbeddedRunner(); + + explicit ClientEmbeddedRunner(std::unique_ptr && client_descriptor_, std::unique_ptr && dbSession_) + : client_descriptors(std::move(client_descriptor_)), db_session(std::move(dbSession_)), log(&Poco::Logger::get("ClientEmbeddedRunner")) + { + } + +private: + void clientRoutine(NameToNameMap envs, String starting_query); + + std::unique_ptr + client_descriptors; // This is used by server thread and client thread, be sure that server only gets them via getDescriptorsForServer + std::atomic_flag started = ATOMIC_FLAG_INIT; + std::atomic_flag finished = ATOMIC_FLAG_INIT; + + ThreadFromGlobalPool client_thread; + std::unique_ptr db_session; + Poco::Logger * log; +}; +} diff --git a/src/Server/ClientEmbedded/IClientDescriptorSet.h b/src/Server/ClientEmbedded/IClientDescriptorSet.h new file mode 100644 index 00000000000..1f819950b4b --- /dev/null +++ b/src/Server/ClientEmbedded/IClientDescriptorSet.h @@ -0,0 +1,39 @@ +#pragma once + +#include + +namespace DB +{ + + +// Interface, which handles descriptor pairs, that could be attached to embedded client +class IClientDescriptorSet +{ +public: + struct DescriptorSet + { + int in = -1; + int out = -1; + int err = -1; + }; + + struct StreamSet + { + std::istream & in; + std::ostream & out; + std::ostream & err; + }; + virtual DescriptorSet getDescriptorsForClient() = 0; + + virtual DescriptorSet getDescriptorsForServer() = 0; + + virtual StreamSet getStreamsForClient() = 0; + + virtual bool isPty() const = 0; + + virtual void closeServerDescriptors() = 0; + + virtual ~IClientDescriptorSet() = default; +}; + +} diff --git a/src/Server/ClientEmbedded/PipeClientDescriptorSet.h b/src/Server/ClientEmbedded/PipeClientDescriptorSet.h new file mode 100644 index 00000000000..9084e5f272f --- /dev/null +++ b/src/Server/ClientEmbedded/PipeClientDescriptorSet.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include +#include +#include + +namespace DB +{ + + +class PipeClientDescriptorSet : public IClientDescriptorSet +{ +public: + PipeClientDescriptorSet() + : fd_source(pipe_in.readHandle(), boost::iostreams::never_close_handle) + , fd_sink(pipe_out.writeHandle(), boost::iostreams::never_close_handle) + , fd_sink_err(pipe_err.writeHandle(), boost::iostreams::never_close_handle) + , input_stream(fd_source) + , output_stream(fd_sink) + , output_stream_err(fd_sink_err) + { + output_stream << std::unitbuf; + output_stream_err << std::unitbuf; + } + + DescriptorSet getDescriptorsForClient() override + { + return DescriptorSet{.in = pipe_in.readHandle(), .out = pipe_out.writeHandle(), .err = pipe_err.writeHandle()}; + } + + DescriptorSet getDescriptorsForServer() override + { + return DescriptorSet{.in = pipe_in.writeHandle(), .out = pipe_out.readHandle(), .err = pipe_err.readHandle()}; + } + + StreamSet getStreamsForClient() override { return StreamSet{.in = input_stream, .out = output_stream, .err = output_stream_err}; } + + void closeServerDescriptors() override + { + pipe_in.close(Poco::Pipe::CLOSE_WRITE); + pipe_out.close(Poco::Pipe::CLOSE_READ); + pipe_err.close(Poco::Pipe::CLOSE_READ); + } + + bool isPty() const override { return false; } + + ~PipeClientDescriptorSet() override = default; + +private: + Poco::Pipe pipe_in; + Poco::Pipe pipe_out; + Poco::Pipe pipe_err; + + // Provide streams on top of file descriptors + boost::iostreams::file_descriptor_source fd_source; + boost::iostreams::file_descriptor_sink fd_sink; + boost::iostreams::file_descriptor_sink fd_sink_err; + boost::iostreams::stream input_stream; + boost::iostreams::stream output_stream; + boost::iostreams::stream output_stream_err; +}; + +} diff --git a/src/Server/ClientEmbedded/PtyClientDescriptorSet.cpp b/src/Server/ClientEmbedded/PtyClientDescriptorSet.cpp new file mode 100644 index 00000000000..4d613bccbe5 --- /dev/null +++ b/src/Server/ClientEmbedded/PtyClientDescriptorSet.cpp @@ -0,0 +1,77 @@ +#include +#include + +#include "pty.h" + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int SYSTEM_ERROR; +} + +void PtyClientDescriptorSet::FileDescriptorWrapper::close() +{ + if (fd != -1) + { + if (::close(fd) != 0 && errno != EINTR) + throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Unexpected error while closing file descriptor"); + } + fd = -1; +} + + +PtyClientDescriptorSet::PtyClientDescriptorSet(const String & term_name_, int width, int height, int width_pixels, int height_pixels) + : term_name(term_name_) +{ + winsize winsize{}; + winsize.ws_col = width; + winsize.ws_row = height; + winsize.ws_xpixel = width_pixels; + winsize.ws_ypixel = height_pixels; + int pty_master_raw = -1, pty_slave_raw = -1; + if (openpty(&pty_master_raw, &pty_slave_raw, nullptr, nullptr, &winsize) != 0) + { + throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Cannot open pty"); + } + pty_master.capture(pty_master_raw); + pty_slave.capture(pty_slave_raw); + fd_source.open(pty_slave.get(), boost::iostreams::never_close_handle); + fd_sink.open(pty_slave.get(), boost::iostreams::never_close_handle); + + // disable signals from tty + struct termios tios; + if (tcgetattr(pty_slave.get(), &tios) == -1) + { + throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Cannot get termios from tty via tcgetattr"); + } + tios.c_lflag &= ~ISIG; + if (tcsetattr(pty_slave.get(), TCSANOW, &tios) == -1) + { + throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Cannot set termios to tty via tcsetattr"); + } + input_stream.open(fd_source); + output_stream.open(fd_sink); + output_stream << std::unitbuf; +} + + +void PtyClientDescriptorSet::changeWindowSize(int width, int height, int width_pixels, int height_pixels) const +{ + winsize winsize{}; + winsize.ws_col = width; + winsize.ws_row = height; + winsize.ws_xpixel = width_pixels; + winsize.ws_ypixel = height_pixels; + + if (ioctl(pty_master.get(), TIOCSWINSZ, &winsize) == -1) + { + throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Cannot update terminal window size via ioctl TIOCSWINSZ"); + } +} + + +PtyClientDescriptorSet::~PtyClientDescriptorSet() = default; + +} diff --git a/src/Server/ClientEmbedded/PtyClientDescriptorSet.h b/src/Server/ClientEmbedded/PtyClientDescriptorSet.h new file mode 100644 index 00000000000..fd9350dc6dc --- /dev/null +++ b/src/Server/ClientEmbedded/PtyClientDescriptorSet.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include +#include +#include "base/types.h" + +namespace DB +{ + + +class PtyClientDescriptorSet : public IClientDescriptorSet +{ +public: + PtyClientDescriptorSet(const String & term_name, int width, int height, int width_pixels, int height_pixels); + + DescriptorSet getDescriptorsForClient() override + { + return DescriptorSet{.in = pty_slave.get(), .out = pty_slave.get(), .err = pty_slave.get()}; + } + + DescriptorSet getDescriptorsForServer() override { return DescriptorSet{.in = pty_master.get(), .out = pty_master.get(), .err = -1}; } + + StreamSet getStreamsForClient() override { return StreamSet{.in = input_stream, .out = output_stream, .err = output_stream}; } + + void changeWindowSize(int width, int height, int width_pixels, int height_pixels) const; + + void closeServerDescriptors() override { pty_master.close(); } + + bool isPty() const override { return true; } + + ~PtyClientDescriptorSet() override; + +private: + class FileDescriptorWrapper + { + public: + FileDescriptorWrapper() = default; + + void capture(int fd_) + { + close(); + fd = fd_; + } + + int get() const { return fd; } + + void close(); + + ~FileDescriptorWrapper() { close(); } // may throw, thus std::terminate + + private: + int fd = -1; + }; + + String term_name; + FileDescriptorWrapper pty_master; + FileDescriptorWrapper pty_slave; + + // Provide streams on top of file descriptors + boost::iostreams::file_descriptor_source fd_source; // handles pty_slave lifetime + boost::iostreams::file_descriptor_sink fd_sink; + boost::iostreams::stream input_stream; + boost::iostreams::stream output_stream; +}; + +} diff --git a/src/Server/SSH/SSHBind.cpp b/src/Server/SSH/SSHBind.cpp new file mode 100644 index 00000000000..d890355b81c --- /dev/null +++ b/src/Server/SSH/SSHBind.cpp @@ -0,0 +1,79 @@ +#include + +#if USE_SSH + +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int SSH_EXCEPTION; +} + +} + +namespace ssh +{ + +SSHBind::SSHBind() : bind(ssh_bind_new(), &deleter) +{ + if (!bind) + { + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to create ssh_bind"); + } +} + +SSHBind::~SSHBind() = default; + +SSHBind::SSHBind(SSHBind && other) noexcept = default; + +SSHBind & SSHBind::operator=(SSHBind && other) noexcept = default; + +void SSHBind::setHostKey(const std::string & key_path) +{ + if (ssh_bind_options_set(bind.get(), SSH_BIND_OPTIONS_HOSTKEY, key_path.c_str()) != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting host key in sshbind due to {}", getError()); +} + +String SSHBind::getError() +{ + return String(ssh_get_error(bind.get())); +} + +void SSHBind::disableDefaultConfig() +{ + bool enable = false; + if (ssh_bind_options_set(bind.get(), SSH_BIND_OPTIONS_PROCESS_CONFIG, &enable) != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed disabling default config in sshbind due to {}", getError()); +} + +void SSHBind::setFd(int fd) +{ + ssh_bind_set_fd(bind.get(), fd); +} + +void SSHBind::listen() +{ + if (ssh_bind_listen(bind.get()) != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed listening in sshbind due to {}", getError()); +} + +void SSHBind::acceptFd(SSHSession & session, int fd) +{ + if (ssh_bind_accept_fd(bind.get(), session.getInternalPtr(), fd) != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed accepting fd in sshbind due to {}", getError()); +} + +void SSHBind::deleter(BindPtr bind) +{ + ssh_bind_free(bind); +} + +} + +#endif diff --git a/src/Server/SSH/SSHBind.h b/src/Server/SSH/SSHBind.h new file mode 100644 index 00000000000..d6a6a637365 --- /dev/null +++ b/src/Server/SSH/SSHBind.h @@ -0,0 +1,54 @@ +#pragma once + +#include "config.h" + +#if USE_SSH + +#include +#include +#include +#include +#include + +struct ssh_bind_struct; + +namespace ssh +{ + +// Wrapper around libssh's ssh_bind +class SSHBind +{ +public: + using BindPtr = ssh_bind_struct *; + + SSHBind(); + ~SSHBind(); + + SSHBind(const SSHBind &) = delete; + SSHBind & operator=(const SSHBind &) = delete; + + SSHBind(SSHBind &&) noexcept; + SSHBind & operator=(SSHBind &&) noexcept; + + // Disables libssh's default config + void disableDefaultConfig(); + // Sets host key for a server. It can be set only one time for each key type. + // If you provide different keys of one type, the first one will be overwritten. + void setHostKey(const std::string & key_path); + // Passes external socket to ssh_bind + void setFd(int fd); + // Listens on a socket. If it was passed via setFd just read hostkeys + void listen(); + // Assign accepted socket to ssh_session + void acceptFd(SSHSession & session, int fd); + String getError(); + +private: + static void deleter(BindPtr bind); + + std::unique_ptr bind; +}; + +} + +#endif diff --git a/src/Server/SSH/SSHChannel.cpp b/src/Server/SSH/SSHChannel.cpp new file mode 100644 index 00000000000..f27d3de2aba --- /dev/null +++ b/src/Server/SSH/SSHChannel.cpp @@ -0,0 +1,87 @@ +#include + +#if USE_SSH + +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int SSH_EXCEPTION; +} + +} + +namespace ssh +{ + +SSHChannel::SSHChannel(SSHSession::SessionPtr session) : channel(ssh_channel_new(session), &deleter) +{ + if (!channel) + { + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to create ssh_channel"); + } +} + +SSHChannel::~SSHChannel() = default; + +SSHChannel::SSHChannel(SSHChannel && other) noexcept : channel(std::move(other.channel)) +{ +} + +SSHChannel & SSHChannel::operator=(SSHChannel && other) noexcept +{ + if (this != &other) + { + channel = std::move(other.channel); + } + return *this; +} + +ssh_channel SSHChannel::getCChannelPtr() const +{ + return channel.get(); +} + +int SSHChannel::read(void * dest, uint32_t count, int isStderr) +{ + return ssh_channel_read(channel.get(), dest, count, isStderr); +} + +int SSHChannel::readTimeout(void * dest, uint32_t count, int isStderr, int timeout) +{ + return ssh_channel_read_timeout(channel.get(), dest, count, isStderr, timeout); +} + +int SSHChannel::write(const void * data, uint32_t len) +{ + return ssh_channel_write(channel.get(), data, len); +} + +int SSHChannel::sendEof() +{ + return ssh_channel_send_eof(channel.get()); +} + +int SSHChannel::close() +{ + return ssh_channel_close(channel.get()); +} + +bool SSHChannel::isOpen() +{ + return ssh_channel_is_open(channel.get()) != 0; +} + +void SSHChannel::deleter(ssh_channel ch) +{ + ssh_channel_free(ch); +} + +} + +#endif diff --git a/src/Server/SSH/SSHChannel.h b/src/Server/SSH/SSHChannel.h new file mode 100644 index 00000000000..1346be4ebcf --- /dev/null +++ b/src/Server/SSH/SSHChannel.h @@ -0,0 +1,50 @@ +#pragma once + +#include "config.h" + +#if USE_SSH + +#include +#include + +struct ssh_channel_struct; + +namespace ssh +{ + +// Wrapper around libssh's ssh_channel +class SSHChannel +{ +public: + using ChannelPtr = ssh_channel_struct *; + + explicit SSHChannel(SSHSession::SessionPtr session); + ~SSHChannel(); + + SSHChannel(const SSHChannel &) = delete; + SSHChannel & operator=(const SSHChannel &) = delete; + + SSHChannel(SSHChannel &&) noexcept; + SSHChannel & operator=(SSHChannel &&) noexcept; + + // Exposes ssh_channel c pointer, which could be used to be passed into other objects + ChannelPtr getCChannelPtr() const; + + int read(void * dest, uint32_t count, int isStderr); + int readTimeout(void * dest, uint32_t count, int isStderr, int timeout); + int write(const void * data, uint32_t len); + // Send eof signal to the other side of channel. It does not close the socket. + int sendEof(); + // Sends eof if it has not been sent and then closes channel. + int close(); + bool isOpen(); + +private: + static void deleter(ChannelPtr ch); + + std::unique_ptr channel; +}; + +} + +#endif diff --git a/src/Server/SSH/SSHEvent.cpp b/src/Server/SSH/SSHEvent.cpp new file mode 100644 index 00000000000..229d4061cc9 --- /dev/null +++ b/src/Server/SSH/SSHEvent.cpp @@ -0,0 +1,93 @@ +#include + +#if USE_SSH + +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int SSH_EXCEPTION; +} + +} + +namespace ssh +{ + +SSHEvent::SSHEvent() : event(ssh_event_new(), &deleter) +{ + if (!event) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to create ssh_event"); +} + +SSHEvent::~SSHEvent() = default; + +SSHEvent::SSHEvent(SSHEvent && other) noexcept : event(std::move(other.event)) +{ +} + +SSHEvent & SSHEvent::operator=(SSHEvent && other) noexcept +{ + if (this != &other) + { + event = std::move(other.event); + } + return *this; +} + +void SSHEvent::addSession(SSHSession & session) +{ + if (ssh_event_add_session(event.get(), session.getInternalPtr()) != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Error adding session to ssh event"); +} + +void SSHEvent::removeSession(SSHSession & session) +{ + if (ssh_event_remove_session(event.get(), session.getInternalPtr()) != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Error removing session from ssh event"); +} + +int SSHEvent::poll(int timeout) +{ + while (true) + { + int rc = ssh_event_dopoll(event.get(), timeout); + + if (rc == SSH_AGAIN) + continue; + if (rc != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Error on polling on ssh event: {}", rc); + + return rc; + } +} + +int SSHEvent::poll() +{ + return poll(-1); +} + +void SSHEvent::addFd(int fd, int events, EventCallback cb, void * userdata) +{ + if (ssh_event_add_fd(event.get(), fd, events, cb, userdata) != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Error on adding custom file descriptor to ssh event"); +} + +void SSHEvent::removeFd(socket_t fd) +{ + ssh_event_remove_fd(event.get(), fd); +} + +void SSHEvent::deleter(EventPtr e) +{ + ssh_event_free(e); +} + +} + +#endif diff --git a/src/Server/SSH/SSHEvent.h b/src/Server/SSH/SSHEvent.h new file mode 100644 index 00000000000..5adea6c3ef2 --- /dev/null +++ b/src/Server/SSH/SSHEvent.h @@ -0,0 +1,48 @@ +#pragma once + +#include "config.h" + +#if USE_SSH + +#include +#include + +struct ssh_event_struct; + +namespace ssh +{ + +// Wrapper around ssh_event from libssh +class SSHEvent +{ +public: + using EventPtr = ssh_event_struct *; + using EventCallback = int (*)(int fd, int revents, void * userdata); + + SSHEvent(); + ~SSHEvent(); + + SSHEvent(const SSHEvent &) = delete; + SSHEvent & operator=(const SSHEvent &) = delete; + + SSHEvent(SSHEvent &&) noexcept; + SSHEvent & operator=(SSHEvent &&) noexcept; + + // Adds session's socket to event. Now callbacks will be executed for this session on poll + void addSession(SSHSession & session); + void removeSession(SSHSession & session); + // Add fd to ssh_event and assign callbacks on fd event + void addFd(int fd, int events, EventCallback cb, void * userdata); + void removeFd(int fd); + int poll(int timeout); + int poll(); + +private: + static void deleter(EventPtr e); + + std::unique_ptr event; +}; + +} + +#endif diff --git a/src/Server/SSH/SSHPtyHandler.cpp b/src/Server/SSH/SSHPtyHandler.cpp new file mode 100644 index 00000000000..1cc7c40e951 --- /dev/null +++ b/src/Server/SSH/SSHPtyHandler.cpp @@ -0,0 +1,519 @@ +#include + +#if USE_SSH + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +namespace +{ + +/* +Need to generate adapter functions, such for each member function, for example: + +class SessionCallback +{ +For this: + ssh_channel channelOpen(ssh_session session) + { + channel = SSHChannel(session); + return channel->get(); + } + + +Generate this: + static ssh_channel channelOpenAdapter(ssh_session session, void * userdata) + { + auto * self = static_cast(userdata); + return self->channel_open; + } +} + +We just static cast userdata to our class and then call member function. +This is needed to use c++ classes in libssh callbacks. +Maybe there is a better way? Or just write boilerplate code and avoid macros? +*/ +#define GENERATE_ADAPTER_FUNCTION(class, func_name, return_type) \ + template \ + static return_type func_name##Adapter(Args... args, void * userdata) \ + { \ + auto * self = static_cast(userdata); \ + return self->func_name(args...); \ + } +} + +namespace DB +{ + +namespace +{ + + +// Wrapper around ssh_channel_callbacks. Each callback must not throw any exceptions, as c code is executed +class ChannelCallback +{ +public: + using DescriptorSet = IClientDescriptorSet::DescriptorSet; + + explicit ChannelCallback(::ssh::SSHChannel && channel_, std::unique_ptr && dbSession_) + : channel(std::move(channel_)), db_session(std::move(dbSession_)), log(&Poco::Logger::get("SSHChannelCallback")) + { + channel_cb.userdata = this; + channel_cb.channel_pty_request_function = ptyRequestAdapter; + channel_cb.channel_shell_request_function = shellRequestAdapter; + channel_cb.channel_data_function = dataFunctionAdapter; + channel_cb.channel_pty_window_change_function = ptyResizeAdapter; + channel_cb.channel_env_request_function = envRequestAdapter; + channel_cb.channel_exec_request_function = execRequestAdapter; + channel_cb.channel_subsystem_request_function = subsystemRequestAdapter; + ssh_callbacks_init(&channel_cb) ssh_set_channel_callbacks(channel.getCChannelPtr(), &channel_cb); + } + + bool hasClientFinished() { return client_runner.has_value() && client_runner->hasFinished(); } + + + DescriptorSet client_input_output; + ::ssh::SSHChannel channel; + std::unique_ptr db_session; + NameToNameMap env; + std::optional client_runner; + Poco::Logger * log; + +private: + int ptyRequest(ssh_session, ssh_channel, const char * term, int width, int height, int width_pixels, int height_pixels) noexcept + { + LOG_TRACE(log, "Received pty request"); + if (!db_session || client_runner.has_value()) + return SSH_ERROR; + try + { + auto client_descriptors = std::make_unique(String(term), width, height, width_pixels, height_pixels); + client_runner.emplace(std::move(client_descriptors), std::move(db_session)); + } + catch (...) + { + tryLogCurrentException(log, "Exception from creating pty"); + return SSH_ERROR; + } + + return SSH_OK; + } + + GENERATE_ADAPTER_FUNCTION(ChannelCallback, ptyRequest, int) + + int ptyResize(ssh_session, ssh_channel, int width, int height, int width_pixels, int height_pixels) noexcept + { + LOG_TRACE(log, "Received pty resize"); + if (!client_runner.has_value() || !client_runner->hasPty()) + { + return SSH_ERROR; + } + + try + { + client_runner->changeWindowSize(width, height, width_pixels, height_pixels); + return SSH_OK; + } + catch (...) + { + tryLogCurrentException(log, "Exception from changing window size"); + return SSH_ERROR; + } + } + + GENERATE_ADAPTER_FUNCTION(ChannelCallback, ptyResize, int) + + int dataFunction(ssh_session, ssh_channel, void * data, uint32_t len, int /*is_stderr*/) const noexcept + { + if (len == 0 || client_input_output.in == -1) + { + return 0; + } + + return static_cast(write(client_input_output.in, data, len)); + } + + GENERATE_ADAPTER_FUNCTION(ChannelCallback, dataFunction, int) + + int subsystemRequest(ssh_session, ssh_channel, const char * subsystem) noexcept + { + LOG_TRACE(log, "Received subsystem request"); + if (strcmp(subsystem, "ch-client") != 0) + { + return SSH_ERROR; + } + LOG_TRACE(log, "Subsystem is supported"); + if (!client_runner.has_value() || client_runner->hasStarted() || !client_runner->hasPty()) + { + return SSH_ERROR; + } + + try + { + client_runner->run(env); + client_input_output = client_runner->getDescriptorsForServer(); + return SSH_OK; + } + catch (...) + { + tryLogCurrentException(log, "Exception from starting client"); + return SSH_ERROR; + } + } + + GENERATE_ADAPTER_FUNCTION(ChannelCallback, subsystemRequest, int) + + int shellRequest(ssh_session, ssh_channel) noexcept + { + LOG_TRACE(log, "Received shell request"); + if (!client_runner.has_value() || client_runner->hasStarted() || !client_runner->hasPty()) + { + return SSH_ERROR; + } + + try + { + client_runner->run(env); + client_input_output = client_runner->getDescriptorsForServer(); + return SSH_OK; + } + catch (...) + { + tryLogCurrentException(log, "Exception from starting client"); + return SSH_ERROR; + } + } + + GENERATE_ADAPTER_FUNCTION(ChannelCallback, shellRequest, int) + + int envRequest(ssh_session, ssh_channel, const char * env_name, const char * env_value) + { + LOG_TRACE(log, "Received env request"); + env[env_name] = env_value; + return SSH_OK; + } + + GENERATE_ADAPTER_FUNCTION(ChannelCallback, envRequest, int) + + int execNopty(const String & command) + { + if (db_session) + { + try + { + auto client_descriptors = std::make_unique(); + client_runner.emplace(std::move(client_descriptors), std::move(db_session)); + client_runner->run(env, command); + client_input_output = client_runner->getDescriptorsForServer(); + } + catch (...) + { + tryLogCurrentException(log, "Exception from starting client with no pty"); + return SSH_ERROR; + } + } + return SSH_OK; + } + + int execRequest(ssh_session, ssh_channel, const char * command) + { + LOG_TRACE(log, "Received exec request"); + if (client_runner.has_value() && (client_runner->hasStarted() || !client_runner->hasPty())) + { + return SSH_ERROR; + } + if (client_runner.has_value()) + { + try + { + client_runner->run(env, command); + client_input_output = client_runner->getDescriptorsForServer(); + return SSH_OK; + } + catch (...) + { + tryLogCurrentException(log, "Exception from starting client with pre entered query"); + return SSH_ERROR; + } + } + return execNopty(String(command)); + } + + GENERATE_ADAPTER_FUNCTION(ChannelCallback, execRequest, int) + + + ssh_channel_callbacks_struct channel_cb = {}; +}; + +int process_stdout(socket_t fd, int revents, void * userdata) +{ + char buf[1024]; + int n = -1; + ssh_channel channel = static_cast(userdata); + + if (channel != nullptr && (revents & POLLIN) != 0) + { + n = static_cast(read(fd, buf, 1024)); + if (n > 0) + { + ssh_channel_write(channel, buf, n); + } + } + + return n; +} + +int process_stderr(socket_t fd, int revents, void * userdata) +{ + char buf[1024]; + int n = -1; + ssh_channel channel = static_cast(userdata); + + if (channel != nullptr && (revents & POLLIN) != 0) + { + n = static_cast(read(fd, buf, 1024)); + if (n > 0) + { + ssh_channel_write_stderr(channel, buf, n); + } + } + + return n; +} + +// Wrapper around ssh_server_callbacks. Each callback must not throw any exceptions, as c code is executed +class SessionCallback +{ +public: + explicit SessionCallback(::ssh::SSHSession & session, IServer & server, const Poco::Net::SocketAddress & address_) + : server_context(server.context()), peer_address(address_), log(&Poco::Logger::get("SSHSessionCallback")) + { + server_cb.userdata = this; + server_cb.auth_password_function = authPasswordAdapter; + server_cb.auth_pubkey_function = authPublickeyAdapter; + ssh_set_auth_methods(session.getInternalPtr(), SSH_AUTH_METHOD_PASSWORD | SSH_AUTH_METHOD_PUBLICKEY); + server_cb.channel_open_request_session_function = channelOpenAdapter; + ssh_callbacks_init(&server_cb) ssh_set_server_callbacks(session.getInternalPtr(), &server_cb); + } + + size_t auth_attempts = 0; + bool authenticated = false; + std::unique_ptr db_session; + DB::ContextMutablePtr server_context; + Poco::Net::SocketAddress peer_address; + std::unique_ptr channel_callback; + Poco::Logger * log; + +private: + ssh_channel channelOpen(ssh_session session) noexcept + { + LOG_DEBUG(log, "Opening a channel"); + if (!db_session) + { + return nullptr; + } + try + { + auto channel = ::ssh::SSHChannel(session); + channel_callback = std::make_unique(std::move(channel), std::move(db_session)); + return channel_callback->channel.getCChannelPtr(); + } + catch (...) + { + tryLogCurrentException(log, "Error while opening channel:"); + return nullptr; + } + } + + GENERATE_ADAPTER_FUNCTION(SessionCallback, channelOpen, ssh_channel) + + int authPassword(ssh_session, const char * user, const char * pass) noexcept + { + try + { + LOG_TRACE(log, "Authenticating with password"); + auto db_session_created = std::make_unique(server_context, ClientInfo::Interface::LOCAL); + String user_name(user), password(pass); + db_session_created->authenticate(user_name, password, peer_address); + authenticated = true; + db_session = std::move(db_session_created); + return SSH_AUTH_SUCCESS; + } + catch (...) + { + ++auth_attempts; + return SSH_AUTH_DENIED; + } + } + + GENERATE_ADAPTER_FUNCTION(SessionCallback, authPassword, int) + + int authPublickey(ssh_session, const char * user, ssh_key key, char signature_state) noexcept + { + try + { + LOG_TRACE(log, "Authenticating with public key"); + auto db_session_created = std::make_unique(server_context, ClientInfo::Interface::LOCAL); + String user_name(user); + + + if (signature_state == SSH_PUBLICKEY_STATE_NONE) + { + // This is the case when user wants to check if he is able to use this type of authentication. + // Also here we may check if the key is associated with the user, but current session + // authentication mechanism doesn't support it. + + const auto user_authentication_types = db_session_created->getAuthenticationTypes(user_name); + + for (auto user_authentication_type : user_authentication_types) + if (user_authentication_type == AuthenticationType::SSH_KEY) + + return SSH_AUTH_DENIED; + } + + if (signature_state != SSH_PUBLICKEY_STATE_VALID) + { + ++auth_attempts; + return SSH_AUTH_DENIED; + } + + /// FIXME: Generate random string + String challenge="Hello..."; + + // The signature is checked, so just verify that user is associated with publickey. + // Function will throw if authentication fails. + db_session_created->authenticate(SshCredentials{user_name, SSHKey(key).signString(challenge), challenge}, peer_address); + + + authenticated = true; + db_session = std::move(db_session_created); + return SSH_AUTH_SUCCESS; + } + catch (...) + { + ++auth_attempts; + return SSH_AUTH_DENIED; + } + } + + GENERATE_ADAPTER_FUNCTION(SessionCallback, authPublickey, int) + + ssh_server_callbacks_struct server_cb = {}; +}; + +} + +SSHPtyHandler::SSHPtyHandler( + IServer & server_, + ::ssh::SSHSession session_, + const Poco::Net::StreamSocket & socket, + unsigned int max_auth_attempts_, + unsigned int auth_timeout_seconds_, + unsigned int finish_timeout_seconds_, + unsigned int event_poll_interval_milliseconds_) + : Poco::Net::TCPServerConnection(socket) + , server(server_) + , log(&Poco::Logger::get("SSHPtyHandler")) + , session(std::move(session_)) + , max_auth_attempts(max_auth_attempts_) + , auth_timeout_seconds(auth_timeout_seconds_) + , finish_timeout_seconds(finish_timeout_seconds_) + , event_poll_interval_milliseconds(event_poll_interval_milliseconds_) +{ +} + +void SSHPtyHandler::run() +{ + ::ssh::SSHEvent event; + SessionCallback sdata(session, server, socket().peerAddress()); + session.handleKeyExchange(); + event.addSession(session); + int max_iterations = auth_timeout_seconds * 1000 / event_poll_interval_milliseconds; + int n = 0; + while (!sdata.authenticated || !sdata.channel_callback) + { + /* If the user has used up all attempts, or if he hasn't been able to + * authenticate in auth_timeout_seconds, disconnect. */ + if (sdata.auth_attempts >= max_auth_attempts || n >= max_iterations) + { + return; + } + + if (server.isCancelled()) + { + return; + } + event.poll(event_poll_interval_milliseconds); + n++; + } + bool fds_set = false; + + do + { + /* Poll the main event which takes care of the session, the channel and + * even our client's stdout/stderr (once it's started). */ + event.poll(event_poll_interval_milliseconds); + + /* If client's stdout/stderr has been registered with the event, + * or the client hasn't started yet, continue. */ + if (fds_set || sdata.channel_callback->client_input_output.out == -1) + { + continue; + } + /* Executed only once, once the client starts. */ + fds_set = true; + + /* If stdout valid, add stdout to be monitored by the poll event. */ + if (sdata.channel_callback->client_input_output.out != -1) + { + event.addFd(sdata.channel_callback->client_input_output.out, POLLIN, process_stdout, sdata.channel_callback->channel.getCChannelPtr()); + } + if (sdata.channel_callback->client_input_output.err != -1) + { + event.addFd(sdata.channel_callback->client_input_output.err, POLLIN, process_stderr, sdata.channel_callback->channel.getCChannelPtr()); + } + } while (sdata.channel_callback->channel.isOpen() && !sdata.channel_callback->hasClientFinished() && !server.isCancelled()); + + LOG_DEBUG( + log, + "Finishing connection with state: channel open: {}, embedded client finished: {}, server cancelled: {}", + sdata.channel_callback->channel.isOpen(), sdata.channel_callback->hasClientFinished(), server.isCancelled() + ); + + event.removeFd(sdata.channel_callback->client_input_output.out); + event.removeFd(sdata.channel_callback->client_input_output.err); + + + sdata.channel_callback->channel.sendEof(); + sdata.channel_callback->channel.close(); + + /* Wait up to finish_timeout_seconds seconds for the client to terminate the session. */ + max_iterations = finish_timeout_seconds * 1000 / event_poll_interval_milliseconds; + for (n = 0; n < max_iterations && !session.hasFinished(); n++) + { + event.poll(event_poll_interval_milliseconds); + } + LOG_DEBUG(log, "Connection closed"); +} + +} + +#endif diff --git a/src/Server/SSH/SSHPtyHandler.h b/src/Server/SSH/SSHPtyHandler.h new file mode 100644 index 00000000000..2038788741d --- /dev/null +++ b/src/Server/SSH/SSHPtyHandler.h @@ -0,0 +1,43 @@ +#pragma once + +#include "config.h" + +#if USE_SSH + +#include +#include +#include +#include + +namespace DB +{ + +class SSHPtyHandler : public Poco::Net::TCPServerConnection +{ +public: + explicit SSHPtyHandler + ( + IServer & server_, + ::ssh::SSHSession session_, + const Poco::Net::StreamSocket & socket, + unsigned int max_auth_attempts_, + unsigned int auth_timeout_seconds_, + unsigned int finish_timeout_seconds_, + unsigned int event_poll_interval_milliseconds_ + ); + + void run() override; + +private: + IServer & server; + Poco::Logger * log; + ::ssh::SSHSession session; + unsigned int max_auth_attempts; + unsigned int auth_timeout_seconds; + unsigned int finish_timeout_seconds; + unsigned int event_poll_interval_milliseconds; +}; + +} + +#endif diff --git a/src/Server/SSH/SSHPtyHandlerFactory.h b/src/Server/SSH/SSHPtyHandlerFactory.h new file mode 100644 index 00000000000..35e7a52f81b --- /dev/null +++ b/src/Server/SSH/SSHPtyHandlerFactory.h @@ -0,0 +1,120 @@ +#pragma once + +#include "config.h" + +#if USE_SSH + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Poco +{ +class Logger; +} +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} + +class SSHPtyHandlerFactory : public TCPServerConnectionFactory +{ +private: + IServer & server; + Poco::Logger * log; + ::ssh::SSHBind bind; + unsigned int max_auth_attempts; + unsigned int auth_timeout_seconds; + unsigned int finish_timeout_seconds; + unsigned int event_poll_interval_milliseconds; + std::optional read_write_timeout_seconds; // optional here, as libssh has its own defaults + std::optional read_write_timeout_micro_seconds; + +public: + explicit SSHPtyHandlerFactory( + IServer & server_, const Poco::Util::AbstractConfiguration & config) + : server(server_), log(&Poco::Logger::get("SSHHandlerFactory")) + { + LOG_INFO(log, "Initializing sshbind"); + bind.disableDefaultConfig(); + + + String prefix = "ssh."; + auto rsa_key = config.getString(prefix + "host_rsa_key", ""); + auto ecdsa_key = config.getString(prefix + "host_ecdsa_key", ""); + auto ed25519_key = config.getString(prefix + "host_ed25519_key", ""); + max_auth_attempts = config.getUInt("max_auth_attempts", 4); + auth_timeout_seconds = config.getUInt("auth_timeout_seconds", 10); + finish_timeout_seconds = config.getUInt("finish_timeout_seconds", 5); + event_poll_interval_milliseconds = config.getUInt("event_poll_interval_milliseconds", 100); + if (config.has("read_write_timeout_seconds")) + { + read_write_timeout_seconds = config.getInt("read_write_timeout_seconds"); + if (read_write_timeout_seconds.value() < 0) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Negative timeout specified"); + } + if (config.has("read_write_timeout_micro_seconds")) + { + + read_write_timeout_micro_seconds = config.getInt("read_write_timeout_micro_seconds"); + if (read_write_timeout_micro_seconds.value() < 0) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Negative timeout specified"); + } + + if (event_poll_interval_milliseconds == 0) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Poll interval must be positive"); + if (auth_timeout_seconds * 1000 < event_poll_interval_milliseconds) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Poll interval exceeds auth timeout"); + if (finish_timeout_seconds * 1000 < event_poll_interval_milliseconds) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Poll interval exceeds finish timeout"); + + + if (rsa_key.empty() && ecdsa_key.empty() && ed25519_key.empty()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Host key for ssh endpoint is not initialized"); + if (!rsa_key.empty()) + bind.setHostKey(rsa_key); + if (!ecdsa_key.empty()) + bind.setHostKey(ecdsa_key); + if (!ed25519_key.empty()) + bind.setHostKey(ed25519_key); + } + + Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &) override + { + LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString()); + ::ssh::libsshLogger::initialize(); + ::ssh::SSHSession session; + session.disableSocketOwning(); + session.disableDefaultConfig(); + if (read_write_timeout_seconds.has_value() || read_write_timeout_micro_seconds.has_value()) + { + session.setTimeout(read_write_timeout_seconds.value_or(0), read_write_timeout_micro_seconds.value_or(0)); + } + bind.acceptFd(session, socket.sockfd()); + + return new SSHPtyHandler( + server, + std::move(session), + socket, + max_auth_attempts, + auth_timeout_seconds, + finish_timeout_seconds, + event_poll_interval_milliseconds); + } +}; + +} + +#endif diff --git a/src/Server/SSH/SSHSession.cpp b/src/Server/SSH/SSHSession.cpp new file mode 100644 index 00000000000..a7030aeed53 --- /dev/null +++ b/src/Server/SSH/SSHSession.cpp @@ -0,0 +1,124 @@ +#include + +#if USE_SSH + +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int SSH_EXCEPTION; +} + +} + +namespace ssh +{ + +SSHSession::SSHSession() : session(ssh_new()) +{ + if (!session) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to create ssh_session"); +} + +SSHSession::~SSHSession() +{ + if (session) + ssh_free(session); +} + +SSHSession::SSHSession(SSHSession && rhs) noexcept +{ + *this = std::move(rhs); +} + +SSHSession & SSHSession::operator=(SSHSession && rhs) noexcept +{ + this->session = rhs.session; + rhs.session = nullptr; + return *this; +} + +SSHSession::SessionPtr SSHSession::getInternalPtr() const +{ + return session; +} + +void SSHSession::connect() +{ + int rc = ssh_connect(session); + if (rc != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed connecting in ssh session due due to {}", getError()); + +} + +void SSHSession::disableDefaultConfig() +{ + bool enable = false; + int rc = ssh_options_set(session, SSH_OPTIONS_PROCESS_CONFIG, &enable); + if (rc != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed disabling default config for ssh session due due to {}", getError()); +} + +void SSHSession::disableSocketOwning() +{ + bool owns_socket = false; + int rc = ssh_options_set(session, SSH_OPTIONS_OWNS_SOCKET, &owns_socket); + if (rc != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed disabling socket owning for ssh session due to {}", getError()); +} + +void SSHSession::setPeerHost(const String & host) +{ + int rc = ssh_options_set(session, SSH_OPTIONS_HOST, host.c_str()); + if (rc != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting peer host option for ssh session due due to {}", getError()); +} + +void SSHSession::setFd(int fd) +{ + int rc = ssh_options_set(session, SSH_OPTIONS_FD, &fd); + if (rc != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting fd option for ssh session due due to {}", getError()); +} + +void SSHSession::setTimeout(int timeout, int timeout_usec) +{ + int rc = ssh_options_set(session, SSH_OPTIONS_TIMEOUT, &timeout); + if (rc != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting for ssh session due timeout option due to {}", getError()); + + rc |= ssh_options_set(session, SSH_OPTIONS_TIMEOUT_USEC, &timeout_usec); + if (rc != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting for ssh session due timeout_usec option due to {}", getError()); +} + +void SSHSession::handleKeyExchange() +{ + int rc = ssh_handle_key_exchange(session); + if (rc != SSH_OK) + throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed key exchange for ssh session due to {}", getError()); +} + +void SSHSession::disconnect() +{ + ssh_disconnect(session); +} + +String SSHSession::getError() +{ + return String(ssh_get_error(session)); +} + +bool SSHSession::hasFinished() +{ + return ssh_get_status(session) & (SSH_CLOSED | SSH_CLOSED_ERROR); +} + +} + +#endif diff --git a/src/Server/SSH/SSHSession.h b/src/Server/SSH/SSHSession.h new file mode 100644 index 00000000000..db994a7877b --- /dev/null +++ b/src/Server/SSH/SSHSession.h @@ -0,0 +1,60 @@ +#pragma once + +#include "config.h" + +#if USE_SSH + +#include +#include + +struct ssh_session_struct; + +namespace ssh +{ + +// Wrapper around libssh's ssh_session +class SSHSession +{ +public: + SSHSession(); + ~SSHSession(); + SSHSession(SSHSession &&) noexcept; + SSHSession & operator=(SSHSession &&) noexcept; + + SSHSession(const SSHSession &) = delete; + SSHSession & operator=(const SSHSession &) = delete; + + using SessionPtr = ssh_session_struct *; + /// Get raw pointer from libssh to be able to pass it to other objects + SessionPtr getInternalPtr() const; + + /// Disable reading default libssh configuration + void disableDefaultConfig(); + /// Disable session from closing socket. Can be used when a socket is passed. + void disableSocketOwning(); + + /// Connect / disconnect + void connect(); + void disconnect(); + + /// Configure session + void setPeerHost(const String & host); + // Pass ready socket to session + void setFd(int fd); + void setTimeout(int timeout, int timeout_usec); + + void handleKeyExchange(); + + /// Error handling + String getError(); + + // Check that session was closed + bool hasFinished(); + +private: + SessionPtr session = nullptr; +}; + +} + +#endif diff --git a/src/Server/ServerType.cpp b/src/Server/ServerType.cpp index b0511632e6e..f0e02d79fea 100644 --- a/src/Server/ServerType.cpp +++ b/src/Server/ServerType.cpp @@ -47,6 +47,7 @@ bool ServerType::shouldStart(Type server_type, const std::string & server_custom switch (current_type) { case Type::TCP: + case Type::TCP_SSH: case Type::TCP_WITH_PROXY: case Type::TCP_SECURE: case Type::HTTP: @@ -110,6 +111,9 @@ bool ServerType::shouldStop(const std::string & port_name) const else if (port_name == "tcp_port") port_type = Type::TCP; + else if (port_name == "tcp_ssh_port") + port_type = Type::TCP_SSH; + else if (port_name == "tcp_with_proxy_port") port_type = Type::TCP_WITH_PROXY; diff --git a/src/Server/ServerType.h b/src/Server/ServerType.h index c31fb663811..3a45314c4ff 100644 --- a/src/Server/ServerType.h +++ b/src/Server/ServerType.h @@ -13,6 +13,7 @@ public: { TCP_WITH_PROXY, TCP_SECURE, + TCP_SSH, TCP, HTTP, HTTPS, diff --git a/src/Server/TCPServer.h b/src/Server/TCPServer.h index 219fed5342b..3ec8ca98cbd 100644 --- a/src/Server/TCPServer.h +++ b/src/Server/TCPServer.h @@ -37,6 +37,8 @@ public: UInt16 portNumber() const { return port_number; } + const Poco::Net::ServerSocket& getSocket() { return socket; } + private: TCPServerConnectionFactory::Ptr factory; Poco::Net::ServerSocket socket; diff --git a/tests/integration/test_ssh/__init__.py b/tests/integration/test_ssh/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/integration/test_ssh/configs/server.xml b/tests/integration/test_ssh/configs/server.xml new file mode 100644 index 00000000000..39abc9adf95 --- /dev/null +++ b/tests/integration/test_ssh/configs/server.xml @@ -0,0 +1,6 @@ + + 9022 + + /etc/clickhouse-server/config.d/ssh_host_rsa_key + + diff --git a/tests/integration/test_ssh/configs/users.xml b/tests/integration/test_ssh/configs/users.xml new file mode 100644 index 00000000000..0d6a4fb4a32 --- /dev/null +++ b/tests/integration/test_ssh/configs/users.xml @@ -0,0 +1,12 @@ + + + + + + ssh-ed25519 + AAAAC3NzaC1lZDI1NTE5AAAAIA5p06mOZGpz7ePU57OmQ08v3U+CpWa2u1f9/V/yoZ1n + + + + + diff --git a/tests/integration/test_ssh/keys/lucy_ed25519 b/tests/integration/test_ssh/keys/lucy_ed25519 new file mode 100644 index 00000000000..28ea4ea80ac --- /dev/null +++ b/tests/integration/test_ssh/keys/lucy_ed25519 @@ -0,0 +1,7 @@ +-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW +QyNTUxOQAAACAOadOpjmRqc+3j1OezpkNPL91PgqVmtrtX/f1f8qGdZwAAAKDbJgdb2yYH +WwAAAAtzc2gtZWQyNTUxOQAAACAOadOpjmRqc+3j1OezpkNPL91PgqVmtrtX/f1f8qGdZw +AAAECcs+RQOSe++AHZNDgAPwkfwd6t6e6HLN7c0ZDFXAGJ0g5p06mOZGpz7ePU57OmQ08v +3U+CpWa2u1f9/V/yoZ1nAAAAFnVidW50dUBpcC0xMC0xMC0xMC0xMDQBAgMEBQYH +-----END OPENSSH PRIVATE KEY----- diff --git a/tests/integration/test_ssh/keys/lucy_ed25519.pub b/tests/integration/test_ssh/keys/lucy_ed25519.pub new file mode 100644 index 00000000000..a482ac4ac40 --- /dev/null +++ b/tests/integration/test_ssh/keys/lucy_ed25519.pub @@ -0,0 +1 @@ +ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA5p06mOZGpz7ePU57OmQ08v3U+CpWa2u1f9/V/yoZ1n test diff --git a/tests/integration/test_ssh/keys/ssh_host_rsa_key b/tests/integration/test_ssh/keys/ssh_host_rsa_key new file mode 100644 index 00000000000..c2b3fc6e819 --- /dev/null +++ b/tests/integration/test_ssh/keys/ssh_host_rsa_key @@ -0,0 +1,49 @@ +-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAACFwAAAAdzc2gtcn +NhAAAAAwEAAQAAAgEArzrm/JkXzklKKcz1IKY6ivokvWmvpJRx3TmJP/hA/hGQ+YZGen7A +2RD+enL+pwnqjknGBMFfyaGJWFfOHVlYxZ4BTkpUXIayVm1hiTnfoNMvXN+dKbf8K/u95U +VLPhW08e1+ymB+ZBal7lSZL2gLcyDT4Riy8Yu3kDuxo88GRp3hXU+Ygn7FMTxYM+yCj8rx +9nIqwYA721uXTmADPYZLh9j+SJI4nJN0eiQ/9IrMRJDTeEPi3M0pXLXIdC13j4O9+xY8Rw +PVOoQl+5ntEWgfWzu1LLf6ODqZKr1V/yMuHKXCKIqbwmb/pfwaec0yaDDFQT48DsI14+pB +O3fL7cPzqBCEeeDugXyiyb83zNy9qPLAdd0w16tkYrNYp9C/LRZP+M6GvFPX9PJ/130Yf7 +q3GeRJAv3bL3MZ9pbqFJpzzrVgAeqRXF/C73zGENaffKLGFJ5CyxZKI+VfhtIbrp23PDDz +JFTMBQuy9omdevfqRXuDPEhuwFvpTkQumzX3OHNIXKiyoAwDd4wXRNWWqj7eEPNUATeY/k +lbHfRGjXFNNhDCptrekT6mCYGYo+BzEoKGui5onoS4yx1hMrykka1SzzAKlljqSY/ezlTg +5Z+xcxD9XWYZXFannB7LJZvGAJ1QWeoH48FGBNKOTjKlOiNRL2JuvlhPVWkEk/3AdcPdGn +0AAAdQDlyKDQ5cig0AAAAHc3NoLXJzYQAAAgEArzrm/JkXzklKKcz1IKY6ivokvWmvpJRx +3TmJP/hA/hGQ+YZGen7A2RD+enL+pwnqjknGBMFfyaGJWFfOHVlYxZ4BTkpUXIayVm1hiT +nfoNMvXN+dKbf8K/u95UVLPhW08e1+ymB+ZBal7lSZL2gLcyDT4Riy8Yu3kDuxo88GRp3h +XU+Ygn7FMTxYM+yCj8rx9nIqwYA721uXTmADPYZLh9j+SJI4nJN0eiQ/9IrMRJDTeEPi3M +0pXLXIdC13j4O9+xY8RwPVOoQl+5ntEWgfWzu1LLf6ODqZKr1V/yMuHKXCKIqbwmb/pfwa +ec0yaDDFQT48DsI14+pBO3fL7cPzqBCEeeDugXyiyb83zNy9qPLAdd0w16tkYrNYp9C/LR +ZP+M6GvFPX9PJ/130Yf7q3GeRJAv3bL3MZ9pbqFJpzzrVgAeqRXF/C73zGENaffKLGFJ5C +yxZKI+VfhtIbrp23PDDzJFTMBQuy9omdevfqRXuDPEhuwFvpTkQumzX3OHNIXKiyoAwDd4 +wXRNWWqj7eEPNUATeY/klbHfRGjXFNNhDCptrekT6mCYGYo+BzEoKGui5onoS4yx1hMryk +ka1SzzAKlljqSY/ezlTg5Z+xcxD9XWYZXFannB7LJZvGAJ1QWeoH48FGBNKOTjKlOiNRL2 +JuvlhPVWkEk/3AdcPdGn0AAAADAQABAAACAAK4gt9zzu25idqH5GU0qQSY66knYl0X4VR+ +/1V24kOaTLP6if4pUgDls8gxkh76/TtdbFvY4GzZ1XsnizIE1Zai4gd0qsZA7wDwuJ2o77 +vCZlioplUAZydX5eD5K9CYbcxPLGVE4lIZjtAvp+YN9of+LzdGBRG1A4Y+SbNZg9RgFkHO +g8gyLrP4bg9HS58B50pH9KtdZy9WM/gYEFMJ1zvUlIeBFPgg9cSGU04vvlYD6Na8S7KgX3 +ZtJKlJ36kW/8/9ENIykT3dcmlPB7LKjBEN5Bx2sk83TkyEa4SviNZBcPzbAyLRuSU7IJOq +uFviHwL75auHli2fL96DHvth8eUKD1UeissdnvVRcV/FYE2GXabBLDr4YIn9h4bLJ9lI/d +eDDPFagASWJisk1PfGwpEdpw6tZrYEerOkQTIwBRShB4xpYcCJ4eACafL64U9Az2jFdxK3 +gaLUr1Bkhus48Pp4YIGrhClFz3yIPTkS9XmIBlnrxa64cNH0A+i9yaAe6esvTPRyDRT3MB +oPwZQczzuwQaODEVFz5a+L4U2pg4vBK2StLm0166R1lz0czXeFksx97F+8CA/gvPooVtI4 +6SZpqJrsQ3ajjmQ/OYDWn8owzb2o463xy2MNsy94cMb7CkO+OU8tV8ytW0hzigFvRRtB3P +YoOygUHIYE7f35hUMjAAABACI1GE2wzgqoqKmtTQtSY4LfX86Y4j2W+7tscirzPG0XcGrh +fl2XHht21n7C2LSU3AoXgT9uyzBrn1oCrV+bUusi7wnRQ2DStLlIIegz+lTbZfa+CV3km4 +mJh/BCeunRq/XrJXN2XMNPcVcR4ProqVdTBYXW2gOJYC+wkrawuKZPCDXS3BOKXlPJI3/m +sggxzHIZL/4ndrez/e4QR8GTpRzT6gM4PbgtEePL7XbPjWsD4BrOWM4wZmf7T2z4ohc/bj +4hnJiAokhKoir/4wKZGJ7OA1b7rW/mFFpIMG+1oVqQZuJwa9HkbStLjQCJu5xywlUVJtqK +BXy4kV4+SbsId1kAAAEBAN5JsuqpJnQN22S3iq4KIHIL1/vJCYm/xweVYAhgyJ1DLfPH12 +EEwUT72apAG4URlarL8hcgJYKZ73JZuzVH6nW6xLS8U0NeIblIh74qwV008i8av9qprCfV +LexyRx0Zf5g85GA72kZTKTCD0LdzjjMHGoWcx34aAaPSKvQdsd8anttRL34LefG9VQWQwL +IQsZatwfFVM9O7gtN++0Dqd8kAmlHEQcMvpmbhPfpsr+MIG6sO4MRqGJ+8ior2YCdK7cUX +59PlVv6A1B7Fh3WVKfYT+s+yejGMua1aX3auLX3kEHFG0ckjK1i2r5wEbT6Tvjgz7sGFOx ++HfwvkCWz2Bk8AAAEBAMnOMJr4RH8xHI/NIvYB8+0m9inkVCKiDRVIKJQPeOLgc/5jWTTN +uErVIsN+gQXxIhdG3/0tWu/sVIASfGAkbYuvSzRT/bWfTe5Oc6gNoBEC5vGzHwNaJVZnPe +k40ON1mXtVlJ167Ku3xx5IekyIKnWxAnk/00tX7Ig5vDspFIsQuXmUP3Zs29ZlfkPi7ebO +884o15meC+104mbFLgNzKdm+XrQuP+is3sQ7CzB45wSRNwVlBJSWq+qIt0/LyPsX9fShr3 +rRKl3lyUcrgRexwYXI3bnP/LwnvUROToi6/30NE7rcELXZ29E+R4H45GOEGmt8wIU3bJyX +T5tk8CPnK3MAAAAWdWJ1bnR1QGlwLTEwLTEwLTEwLTEwNAECAwQF +-----END OPENSSH PRIVATE KEY----- diff --git a/tests/integration/test_ssh/keys/ssh_host_rsa_key.pub b/tests/integration/test_ssh/keys/ssh_host_rsa_key.pub new file mode 100644 index 00000000000..bb78b0108c5 --- /dev/null +++ b/tests/integration/test_ssh/keys/ssh_host_rsa_key.pub @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCvOub8mRfOSUopzPUgpjqK+iS9aa+klHHdOYk/+ED+EZD5hkZ6fsDZEP56cv6nCeqOScYEwV/JoYlYV84dWVjFngFOSlRchrJWbWGJOd+g0y9c350pt/wr+73lRUs+FbTx7X7KYH5kFqXuVJkvaAtzINPhGLLxi7eQO7GjzwZGneFdT5iCfsUxPFgz7IKPyvH2cirBgDvbW5dOYAM9hkuH2P5Ikjick3R6JD/0isxEkNN4Q+LczSlctch0LXePg737FjxHA9U6hCX7me0RaB9bO7Ust/o4OpkqvVX/Iy4cpcIoipvCZv+l/Bp5zTJoMMVBPjwOwjXj6kE7d8vtw/OoEIR54O6BfKLJvzfM3L2o8sB13TDXq2Ris1in0L8tFk/4zoa8U9f08n/XfRh/urcZ5EkC/dsvcxn2luoUmnPOtWAB6pFcX8LvfMYQ1p98osYUnkLLFkoj5V+G0huunbc8MPMkVMwFC7L2iZ169+pFe4M8SG7AW+lORC6bNfc4c0hcqLKgDAN3jBdE1ZaqPt4Q81QBN5j+SVsd9EaNcU02EMKm2t6RPqYJgZij4HMSgoa6LmiehLjLHWEyvKSRrVLPMAqWWOpJj97OVODln7FzEP1dZhlcVqecHsslm8YAnVBZ6gfjwUYE0o5OMqU6I1EvYm6+WE9VaQST/cB1w90afQ== test diff --git a/tests/integration/test_ssh/test.py b/tests/integration/test_ssh/test.py new file mode 100644 index 00000000000..b60f1922519 --- /dev/null +++ b/tests/integration/test_ssh/test.py @@ -0,0 +1,45 @@ +import subprocess +import pytest +import os +from helpers.cluster import ClickHouseCluster + +cluster = ClickHouseCluster(__file__) +instance = cluster.add_instance( + "node", + main_configs=["configs/server.xml", "keys/ssh_host_rsa_key"], + user_configs=["configs/users.xml"], +) + +SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) + + +@pytest.fixture(scope="module", autouse=True) +def started_cluster(): + try: + cluster.start() + yield cluster + + finally: + cluster.shutdown() + + +def test_simple_query_with_openssh_client(): + ssh_command = ( + "ssh -o StrictHostKeyChecking" + + f"=no lucy@{instance.ip_address} -p 9022" + + f' -i {SCRIPT_DIR}/keys/lucy_ed25519 "select 1"' + ) + + completed_process = subprocess.run( + ssh_command, + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + expected = instance.query("select 1") + + output = completed_process.stdout + + assert output.replace("\n\x00", "\n") == expected diff --git a/utils/check-style/check-style b/utils/check-style/check-style index c3b42be1519..fd572465dfb 100755 --- a/utils/check-style/check-style +++ b/utils/check-style/check-style @@ -328,7 +328,10 @@ std_cerr_cout_excludes=( src/Processors/IProcessor.cpp src/Client/ClientApplicationBase.cpp src/Client/ClientBase.cpp + src/Common/ProgressIndication.h + src/Client/LineReader.h src/Client/LineReader.cpp + src/Client/ReplxxLineReader.h src/Client/QueryFuzzer.cpp src/Client/Suggest.cpp src/Client/ClientBase.h @@ -348,7 +351,7 @@ sources_with_std_cerr_cout=( $( ) ) # Exclude comments for src in "${sources_with_std_cerr_cout[@]}"; do - # suppress stderr, since it may contain warning for #pargma once in headers + # suppress stderr, since it may contain warning for #pragma once in headers if gcc -fpreprocessed -dD -E "$src" 2>/dev/null | grep -F -q -e std::cerr -e std::cout; then echo "$src: uses std::cerr/std::cout" fi