SSH Authentication

This commit is contained in:
Gamezardashvili George 2024-11-19 14:58:50 +00:00 committed by Nikita Mikhaylov
parent 1aceb608f3
commit 41416952c8
71 changed files with 2946 additions and 80 deletions

View File

@ -8,6 +8,10 @@ if (GLIBC_COMPATIBILITY)
add_headers_and_sources(glibc_compatibility .)
add_headers_and_sources(glibc_compatibility musl)
# There are several symbols which exist only in musl.
set (musl_pty_include_dir musl/pty)
if (ARCH_AARCH64)
list (APPEND glibc_compatibility_sources musl/aarch64/syscall.s musl/aarch64/longjmp.s)
set (musl_arch_include_dir musl/aarch64)
@ -40,6 +44,7 @@ if (GLIBC_COMPATIBILITY)
target_no_warning(glibc-compatibility unused-but-set-variable)
target_no_warning(glibc-compatibility builtin-requires-header)
target_include_directories(glibc-compatibility PUBLIC ${musl_pty_include_dir})
target_include_directories(glibc-compatibility PRIVATE libcxxabi ${musl_arch_include_dir})
if (ENABLE_OPENSSL_DYNAMIC)

View File

@ -0,0 +1,10 @@
#include "libc.h"
struct __libc __libc;
size_t __hwcap;
size_t __sysinfo;
char *__progname=0, *__progname_full=0;
weak_alias(__progname, program_invocation_short_name);
weak_alias(__progname_full, program_invocation_name);

View File

@ -0,0 +1,72 @@
#ifndef LIBC_H
#define LIBC_H
#include <stdlib.h>
#include <stdio.h>
#include <limits.h>
struct __locale_map;
struct __locale_struct {
const struct __locale_map *volatile cat[6];
};
struct tls_module {
struct tls_module *next;
void *image;
size_t len, size, align, offset;
};
struct __libc {
int can_do_threads;
int threaded;
int secure;
volatile int threads_minus_1;
size_t *auxv;
struct tls_module *tls_head;
size_t tls_size, tls_align, tls_cnt;
size_t page_size;
struct __locale_struct global_locale;
};
#ifndef PAGE_SIZE
#define PAGE_SIZE libc.page_size
#endif
#ifdef __PIC__
#define ATTR_LIBC_VISIBILITY __attribute__((visibility("hidden")))
#else
#define ATTR_LIBC_VISIBILITY
#endif
extern struct __libc __libc ATTR_LIBC_VISIBILITY;
#define libc __libc
extern size_t __hwcap ATTR_LIBC_VISIBILITY;
extern size_t __sysinfo ATTR_LIBC_VISIBILITY;
extern char *__progname, *__progname_full;
/* Designed to avoid any overhead in non-threaded processes */
void __lock(volatile int *) ATTR_LIBC_VISIBILITY;
void __unlock(volatile int *) ATTR_LIBC_VISIBILITY;
int __lockfile(FILE *) ATTR_LIBC_VISIBILITY;
void __unlockfile(FILE *) ATTR_LIBC_VISIBILITY;
#define LOCK(x) __lock(x)
#define UNLOCK(x) __unlock(x)
void __synccall(void (*)(void *), void *);
int __setxid(int, int, int, int);
extern char **__environ;
#undef weak_alias
#define weak_alias(old, new) \
extern __typeof(old) new __attribute__((weak, alias(#old)))
#undef LFS64_2
#define LFS64_2(x, y) weak_alias(x, y)
#undef LFS64
#define LFS64(x) LFS64_2(x, x##64)
#endif

View File

@ -0,0 +1,40 @@
#include <stdlib.h>
#include <fcntl.h>
#include <unistd.h>
#include "pty.h"
#include <stdio.h>
#include <pthread.h>
/* Nonstandard, but vastly superior to the standard functions */
int openpty(int *pm, int *ps, char *name, const struct termios *tio, const struct winsize *ws)
{
int m, s, n=0, cs;
char buf[20];
m = open("/dev/ptmx", O_RDWR|O_NOCTTY);
if (m < 0) return -1;
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
if (ioctl(m, TIOCSPTLCK, &n) || ioctl (m, TIOCGPTN, &n))
goto fail;
if (!name) name = buf;
snprintf(name, sizeof buf, "/dev/pts/%d", n);
if ((s = open(name, O_RDWR|O_NOCTTY)) < 0)
goto fail;
if (tio) tcsetattr(s, TCSANOW, tio);
if (ws) ioctl(s, TIOCSWINSZ, ws);
*pm = m;
*ps = s;
pthread_setcancelstate(cs, 0);
return 0;
fail:
close(m);
pthread_setcancelstate(cs, 0);
return -1;
}

View File

@ -0,0 +1,34 @@
#include <stdlib.h>
#include <sys/ioctl.h>
#include <stdio.h>
#include <fcntl.h>
#include <errno.h>
#include "libc.h"
#include "syscall.h"
int posix_openpt(int flags)
{
return open("/dev/ptmx", flags);
}
int grantpt(int fd)
{
return 0;
}
int unlockpt(int fd)
{
int unlock = 0;
return ioctl(fd, TIOCSPTLCK, &unlock);
}
int __ptsname_r(int fd, char *buf, size_t len)
{
int pty, err;
if (!buf) len = 0;
if ((err = __syscall(SYS_ioctl, fd, TIOCGPTN, &pty))) return -err;
if (snprintf(buf, len, "/dev/pts/%d", pty) >= len) return ERANGE;
return 0;
}
weak_alias(__ptsname_r, ptsname_r);

View File

@ -0,0 +1,15 @@
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#include <termios.h>
#include <sys/ioctl.h>
int openpty(int *, int *, char *, const struct termios *, const struct winsize *);
int forkpty(int *, char *, const struct termios *, const struct winsize *);
#ifdef __cplusplus
}
#endif

View File

@ -362,7 +362,7 @@ namespace Util
void setArgs(int argc, char * argv[]);
void setArgs(const ArgVec & args);
void getApplicationPath(Poco::Path & path) const;
void processOptions();
void processPocoOptions();
bool findAppConfigFile(const std::string & appName, const std::string & extension, Poco::Path & path) const;
bool findAppConfigFile(const Path & basePath, const std::string & appName, const std::string & extension, Poco::Path & path) const;

View File

@ -149,7 +149,7 @@ void Application::init()
_pConfig->setString("application.cacheDir", Path::cacheHome() + appPath.getBaseName() + Path::separator());
_pConfig->setString("application.tempDir", Path::tempHome() + appPath.getBaseName() + Path::separator());
_pConfig->setString("application.dataDir", Path::dataHome() + appPath.getBaseName() + Path::separator());
processOptions();
processPocoOptions();
}
@ -368,7 +368,7 @@ void Application::setArgs(const ArgVec& args)
}
void Application::processOptions()
void Application::processPocoOptions()
{
defineOptions(_options);
OptionProcessor processor(_options);

View File

@ -0,0 +1,9 @@
---
slug: /en/interfaces/ssh
sidebar_position: 19
sidebar_label: SSH Interface
---
# SSH Interface
You can connect to clickhouse-server using any ssh client!

View File

@ -370,6 +370,8 @@ try
showClientVersion();
}
default_database = config().getString("database", "");
try
{
connect();
@ -458,8 +460,11 @@ void Client::connect()
{
try
{
const auto host = ConnectionParameters::Host{hosts_and_ports[attempted_address_index].host};
const auto database = ConnectionParameters::Database{default_database};
connection_parameters = ConnectionParameters(
config(), hosts_and_ports[attempted_address_index].host, hosts_and_ports[attempted_address_index].port);
config(), host, database, hosts_and_ports[attempted_address_index].port);
if (is_interactive)
std::cout << "Connecting to "
@ -508,6 +513,10 @@ void Client::connect()
server_version = toString(server_version_major) + "." + toString(server_version_minor) + "." + toString(server_version_patch);
load_suggestions = is_interactive && (server_revision >= Suggest::MIN_SERVER_REVISION) && !config().getBool("disable_suggestion", false);
wait_for_suggestions_to_load = config().getBool("wait_for_suggestions_to_load", false);
if (load_suggestions)
{
suggestion_limit = config().getInt("suggestion_limit");
}
if (server_display_name = connection->getServerDisplayName(connection_parameters.timeouts); server_display_name.empty())
server_display_name = config().getString("host", "localhost");
@ -1194,10 +1203,42 @@ void Client::processConfig()
echo_queries = config().getBool("echo", false);
ignore_error = config().getBool("ignore-error", false);
auto query_id = config().getString("query_id", "");
query_id = config().getString("query_id", "");
if (!query_id.empty())
global_context->setCurrentQueryId(query_id);
}
if (is_interactive || delayed_interactive)
{
if (home_path.empty())
{
const char * home_path_cstr = getenv("HOME"); // NOLINT(concurrency-mt-unsafe)
if (home_path_cstr)
home_path = home_path_cstr;
}
/// Load command history if present.
if (config().has("history_file"))
history_file = config().getString("history_file");
else
{
auto * history_file_from_env = getenv("CLICKHOUSE_HISTORY_FILE"); // NOLINT(concurrency-mt-unsafe)
if (history_file_from_env)
history_file = history_file_from_env;
else if (!home_path.empty())
history_file = home_path + "/.clickhouse-client-history";
}
}
if (config().has("query"))
{
static_query = config().getRawString("query"); /// Poco configuration should not process substitutions in form of ${...} inside query.
}
pager = config().getString("pager", "");
enable_highlight = config().getBool("highlight", true);
multiline = config().has("multiline");
print_stack_trace = config().getBool("stacktrace", false);
pager = config().getString("pager", "");

View File

@ -2,8 +2,9 @@
#include "Commands.h"
#include <Client/ReplxxLineReader.h>
#include <Client/ClientBase.h>
#include "Common/VersionNumber.h"
#include <Common/VersionNumber.h>
#include <Common/Config/ConfigProcessor.h>
#include <Client/ClientApplicationBase.h>
#include <Common/EventNotifier.h>
#include <Common/filesystemHelpers.h>
#include <Common/ZooKeeper/ZooKeeper.h>
@ -328,7 +329,8 @@ void KeeperClient::runInteractiveReplxx()
query_extenders,
query_delimiters,
word_break_characters,
/* highlighter_= */ {});
/* highlighter_= */ {}
);
lr.enableBracketedPaste();
while (true)

View File

@ -488,7 +488,11 @@ void LocalServer::setupUsers()
void LocalServer::connect()
{
connection_parameters = ConnectionParameters(getClientConfiguration(), "localhost");
connection_parameters = ConnectionParameters(
config(),
ConnectionParameters::Host{"localhost"},
ConnectionParameters::Database{default_database}
);
/// This is needed for table function input(...).
ReadBuffer * in;
@ -502,6 +506,7 @@ void LocalServer::connect()
input = std::make_unique<ReadBufferFromFile>(table_file);
in = input.get();
}
connection = LocalConnection::createConnection(
connection_parameters, client_context, in, need_render_progress, need_render_profile_events, server_display_name);
}

View File

@ -128,6 +128,9 @@
#if USE_SSL
# include <Poco/Net/SecureServerSocket.h>
# include <Server/CertificateReloader.h>
# include <Server/SSH/SSHPtyHandlerFactory.h>
# include <Common/LibSSHInitializer.h>
# include <Common/LibSSHLogger.h>
#endif
#if USE_GRPC
@ -373,6 +376,16 @@ static std::string getCanonicalPath(std::string && path)
return std::move(path);
}
Server::Server()
{
#if USE_SSL
::ssh::LibSSHInitializer::instance();
::ssh::libsshLogger::initialize();
#endif
}
Poco::Net::SocketAddress Server::socketBindListen(
const Poco::Util::AbstractConfiguration & config,
Poco::Net::ServerSocket & socket,
@ -2813,6 +2826,37 @@ void Server::createServers(
});
}
if (server_type.shouldStart(ServerType::Type::TCP_SSH))
{
port_name = "tcp_ssh_port";
createServer(
config,
listen_host,
port_name,
listen_try,
start_servers,
servers,
[&](UInt16 port) -> ProtocolServerAdapter
{
#if USE_SSH
Poco::Net::ServerSocket socket;
auto address = socketBindListen(config, socket, listen_host, port, /* secure = */ false);
return ProtocolServerAdapter(
listen_host,
port_name,
"SSH pty: " + address.toString(),
std::make_unique<TCPServer>(
new SSHPtyHandlerFactory(*this, config),
server_pool,
socket,
new Poco::Net::TCPServerParams));
#else
UNUSED(port);
throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "SSH protocol is disabled for ClickHouse, as it has been built without libssh");
#endif
});
}
if (server_type.shouldStart(ServerType::Type::MYSQL))
{
port_name = "mysql_port";

View File

@ -34,6 +34,9 @@ class ProtocolServerAdapter;
class Server : public BaseDaemon, public IServer
{
public:
Server();
using ServerApplication::run;
Poco::Util::LayeredConfiguration & config() const override

View File

@ -190,7 +190,12 @@ void AuthenticationData::setPasswordHashHex(const String & hash, bool validate)
String AuthenticationData::getPasswordHashHex() const
{
if (type == AuthenticationType::LDAP || type == AuthenticationType::KERBEROS || type == AuthenticationType::SSL_CERTIFICATE)
if (
type == AuthenticationType::LDAP
|| type == AuthenticationType::KERBEROS
|| type == AuthenticationType::SSL_CERTIFICATE
|| type == AuthenticationType::SSH_KEY
)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot get password hex hash for authentication type {}", toString(type));
String hex;

View File

@ -125,4 +125,5 @@ private:
};
#endif
}

View File

@ -0,0 +1,154 @@
#include <stdexcept>
#include <Access/SSH/SSHPublicKey.h>
#include <Common/Exception.h>
#include <Common/clibssh.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SSH_EXCEPTION;
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
}
}
namespace ssh
{
SSHPublicKey::SSHPublicKey(KeyPtr key_, bool own) : key(key_, own ? &deleter : &disabledDeleter)
{ // disable deleter if class is constructed without ownership
if (!key)
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "No ssh_key provided in explicit constructor");
}
}
SSHPublicKey::~SSHPublicKey() = default;
SSHPublicKey::SSHPublicKey(const SSHPublicKey & other) : key(ssh_key_dup(other.key.get()), &deleter)
{
if (!key)
{
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to duplicate ssh_key");
}
}
SSHPublicKey & SSHPublicKey::operator=(const SSHPublicKey & other)
{
if (this != &other)
{
KeyPtr new_key = ssh_key_dup(other.key.get());
if (!new_key)
{
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to duplicate ssh_key");
}
key = UniqueKeyPtr(new_key, deleter); // We don't have access to the pointer from external code, opposed to non owning key object.
// So here we always go for default deleter, regardless of other's
}
return *this;
}
SSHPublicKey::SSHPublicKey(SSHPublicKey && other) noexcept = default;
SSHPublicKey & SSHPublicKey::operator=(SSHPublicKey && other) noexcept = default;
bool SSHPublicKey::operator==(const SSHPublicKey & other) const
{
return isEqual(other);
}
bool SSHPublicKey::isEqual(const SSHPublicKey & other) const
{
int rc = ssh_key_cmp(key.get(), other.key.get(), SSH_KEY_CMP_PUBLIC);
return rc == 0;
}
SSHPublicKey SSHPublicKey::createFromBase64(const String & base64, const String & key_type)
{
KeyPtr key;
int rc = ssh_pki_import_pubkey_base64(base64.c_str(), ssh_key_type_from_name(key_type.c_str()), &key);
if (rc != SSH_OK)
{
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed importing public key from base64 format.\n\
Key: {}\n\
Type: {}",
base64, key_type
);
}
return SSHPublicKey(key);
}
SSHPublicKey SSHPublicKey::createFromFile(const std::string & filename)
{
KeyPtr key;
int rc = ssh_pki_import_pubkey_file(filename.c_str(), &key);
if (rc != SSH_OK)
{
if (rc == SSH_EOF)
{
throw DB::Exception(
DB::ErrorCodes::BAD_ARGUMENTS,
"Can't import ssh public key from file {} as it doesn't exist or permission denied", filename
);
}
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Can't import ssh public key from file {}", filename);
}
return SSHPublicKey(key);
}
SSHPublicKey SSHPublicKey::createNonOwning(KeyPtr key)
{
return SSHPublicKey(key, false);
}
namespace
{
struct CStringDeleter
{
[[maybe_unused]] void operator()(char * ptr) const { std::free(ptr); }
};
}
String SSHPublicKey::getBase64Representation() const
{
char * buf = nullptr;
int rc = ssh_pki_export_pubkey_base64(key.get(), &buf);
if (rc != SSH_OK)
{
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to export public key to base64");
}
// Create a String from cstring, which makes a copy of the first one and requires freeing memory after it
std::unique_ptr<char, CStringDeleter> buf_ptr(buf); // This is to safely manage buf memory
return String(buf_ptr.get());
}
String SSHPublicKey::getType() const
{
const char * type_c = ssh_key_type_to_char(ssh_key_type(key.get()));
if (type_c == nullptr)
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Key type is unknown or no key contained");
}
return String(type_c);
}
std::size_t SSHPublicKey::KeyHasher::operator()(const SSHPublicKey & input_key) const
{
String combined_string(input_key.getType());
combined_string += input_key.getBase64Representation();
return string_hasher(combined_string);
}
void SSHPublicKey::deleter(KeyPtr key)
{
ssh_key_free(key);
}
}

View File

@ -0,0 +1,67 @@
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_set>
#include <base/types.h>
struct ssh_key_struct;
namespace ssh
{
class SSHPublicKey
{
private:
class KeyHasher
{
public:
std::size_t operator()(const SSHPublicKey & input_key) const;
private:
std::hash<std::string> string_hasher;
};
public:
using KeyPtr = ssh_key_struct *;
using KeySet = std::unordered_set<SSHPublicKey, KeyHasher>;
SSHPublicKey() = delete;
~SSHPublicKey();
SSHPublicKey(const SSHPublicKey &);
SSHPublicKey & operator=(const SSHPublicKey &);
SSHPublicKey(SSHPublicKey &&) noexcept;
SSHPublicKey & operator=(SSHPublicKey &&) noexcept;
bool operator==(const SSHPublicKey &) const;
bool isEqual(const SSHPublicKey & other) const;
String getBase64Representation() const;
String getType() const;
static SSHPublicKey createFromBase64(const String & base64, const String & key_type);
static SSHPublicKey createFromFile(const String & filename);
// Creates SSHPublicKey, but without owning the memory of ssh_key.
// A user must manage it by himself. (This is implemented for compatibility with libssh callbacks)
static SSHPublicKey createNonOwning(KeyPtr key);
private:
explicit SSHPublicKey(KeyPtr key, bool own = true);
static void deleter(KeyPtr key);
// We may want to not own ssh_key memory, so then we pass this deleter to unique_ptr
static void disabledDeleter(KeyPtr) { }
using UniqueKeyPtr = std::unique_ptr<ssh_key_struct, decltype(&deleter)>;
UniqueKeyPtr key;
};
}

View File

@ -219,6 +219,9 @@ macro(add_object_library name common_path)
endmacro()
add_object_library(clickhouse_access Access)
if (TARGET ch_contrib::ssh)
add_object_library(clickhouse_access_ssh Access/SSH)
endif()
add_object_library(clickhouse_backups Backups)
add_object_library(clickhouse_core Core)
add_object_library(clickhouse_core_mysql Core/MySQL)
@ -259,6 +262,10 @@ set_source_files_properties(Client/ClientBaseOptimizedParts.cpp PROPERTIES COMPI
add_object_library(clickhouse_bridge BridgeHelper)
add_object_library(clickhouse_server Server)
add_object_library(clickhouse_server_http Server/HTTP)
if (TARGET ch_contrib::ssh)
add_object_library(clickhouse_server_ssh Server/SSH)
endif()
add_object_library(clickhouse_server_embedded_client Server/ClientEmbedded)
add_object_library(clickhouse_formats Formats)
add_object_library(clickhouse_processors Processors)
add_object_library(clickhouse_processors_executors Processors/Executors)

View File

@ -47,6 +47,8 @@ protected:
void setupSignalHandler() override;
void addMultiquery(std::string_view query, Arguments & common_arguments) const;
virtual void readArguments(int argc, char ** argv, Arguments & common_arguments, std::vector<Arguments> &, std::vector<Arguments> &) = 0;
private:
void parseAndCheckOptions(OptionsDescription & options_description, po::variables_map & options, Arguments & arguments);

View File

@ -115,11 +115,14 @@ namespace Setting
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
extern const int DEADLOCK_AVOIDED;
extern const int DATABASE_ACCESS_DENIED;
extern const int CLIENT_OUTPUT_FORMAT_SPECIFIED;
extern const int UNKNOWN_PACKET_FROM_SERVER;
extern const int NO_DATA_TO_INSERT;
extern const int UNEXPECTED_PACKET_FROM_SERVER;
extern const int INCORRECT_FILE_NAME;
extern const int INVALID_USAGE_OF_INPUT;
extern const int CANNOT_SET_SIGNAL_HANDLER;
extern const int LOGICAL_ERROR;
@ -544,6 +547,7 @@ try
{
if (!output_format)
{
auto is_embedded = global_context->getApplicationType() == Context::ApplicationType::SERVER;
/// Ignore all results when fuzzing as they can be huge.
if (query_fuzzer_runs)
{
@ -552,7 +556,7 @@ try
}
WriteBuffer * out_buf = nullptr;
if (!pager.empty())
if (!pager.empty() && !is_embedded)
{
if (SIG_ERR == signal(SIGPIPE, SIG_IGN))
throw ErrnoException(ErrorCodes::CANNOT_SET_SIGNAL_HANDLER, "Cannot set signal handler for SIGPIPE");
@ -583,8 +587,12 @@ try
/// The query can specify output format or output file.
if (const auto * query_with_output = dynamic_cast<const ASTQueryWithOutput *>(parsed_query.get()))
{
if (query_with_output->out_file && is_embedded)
{
error_stream << "Out files are disabled when you are running client embedded into server. Ignoring this option.\n";
}
String out_file;
if (query_with_output->out_file)
if (query_with_output->out_file && !is_embedded)
{
select_into_file = true;
@ -628,7 +636,7 @@ try
{
select_into_file_and_stdout = true;
out_file_buf = std::make_unique<ForkWriteBuffer>(std::vector<WriteBufferPtr>{std::move(out_file_buf),
std::make_shared<WriteBufferFromFileDescriptor>(STDOUT_FILENO)});
std::make_shared<WriteBufferFromFileDescriptor>(out_fd)});
}
// We are writing to file, so default format is the same as in non-interactive mode.
@ -642,7 +650,7 @@ try
const auto & id = query_with_output->format->as<ASTIdentifier &>();
current_format = id.name();
}
else if (query_with_output->out_file)
else if (query_with_output->out_file && !is_embedded)
{
auto format_name = FormatFactory::instance().tryGetFormatFromFileName(out_file);
if (format_name)
@ -691,7 +699,7 @@ void ClientBase::initLogsOutputStream()
if (server_logs_file.empty())
{
/// Use stderr by default
out_logs_buf = std::make_unique<AutoCanceledWriteBuffer<WriteBufferFromFileDescriptor>>(STDERR_FILENO);
out_logs_buf = std::make_unique<AutoCanceledWriteBuffer<WriteBufferFromFileDescriptor>>(err_fd);
wb = out_logs_buf.get();
color_logs = stderr_is_a_tty;
}
@ -852,12 +860,22 @@ void ClientBase::initTTYBuffer(ProgressOption progress_option, ProgressOption pr
/// use ProgressOption that was set for the progress bar for progress table as well.
ProgressOption progress = progress_option ? progress_option : progress_table_option;
static constexpr auto tty_file_name = "/dev/tty";
/// Output all progress bar commands to terminal at once to avoid flicker.
/// This size is usually greater than the window size.
static constexpr size_t buf_size = 1024;
// If we are embedded into server, there is no need to access terminal device via opening a file.
// Actually we need to pass tty's name, if we don't want this condition statement,
// because /dev/tty stands for controlling terminal of the process, thus a client will not see progress line.
// So it's easier to just pass a descriptor, without the terminal name.
if (global_context->getApplicationType() == Context::ApplicationType::SERVER)
{
tty_buf = std::make_unique<WriteBufferFromFileDescriptor>(out_fd, buf_size);
return;
}
static constexpr auto tty_file_name = "/dev/tty";
if (is_interactive || progress == ProgressOption::TTY)
{
std::error_code ec;
@ -892,7 +910,7 @@ void ClientBase::initTTYBuffer(ProgressOption progress_option, ProgressOption pr
if (stderr_is_a_tty || progress == ProgressOption::ERR)
{
tty_buf = std::make_unique<AutoCanceledWriteBuffer<WriteBufferFromFileDescriptor>>(STDERR_FILENO, buf_size);
tty_buf = std::make_unique<AutoCanceledWriteBuffer<WriteBufferFromFileDescriptor>>(err_fd, buf_size);
}
else
{
@ -1509,7 +1527,7 @@ void ClientBase::resetOutput()
if (SIG_ERR == signal(SIGQUIT, SIG_DFL))
throw ErrnoException(ErrorCodes::CANNOT_SET_SIGNAL_HANDLER, "Cannot set signal handler for SIGQUIT");
setupSignalHandler();
// setupSignalHandler();
}
pager_cmd = nullptr;
@ -1612,6 +1630,27 @@ void ClientBase::processInsertQuery(const String & query_to_execute, ASTPtr pars
throw Exception(ErrorCodes::NO_DATA_TO_INSERT, "No data to insert");
return;
}
// Validate infile before we pass further, as some files may be unsafe if client is embedded into server
if (global_context->getApplicationType() == Context::ApplicationType::SERVER && parsed_insert_query.infile)
{
const auto & in_file_node = parsed_insert_query.infile->as<ASTLiteral &>();
const auto in_file = in_file_node.value.safeGet<std::string>();
String user_files_absolute_path = fs::weakly_canonical(global_context->getUserFilesPath());
fs::path fs_table_path(in_file);
if (fs_table_path.is_relative())
fs_table_path = user_files_absolute_path / fs_table_path;
/// Do not use fs::canonical or fs::weakly_canonical.
/// Otherwise it will not allow to work with symlinks in `user_files_path` directory.
String path = fs::absolute(fs_table_path).lexically_normal(); /// Normalize path.
auto table_path_stat = fs::status(path);
if (!fs::exists(table_path_stat))
throw Exception(ErrorCodes::INCORRECT_FILE_NAME, "Provided file doesn't exist: {}", in_file);
if (!fileOrSymlinkPathStartsWith(path, user_files_absolute_path))
throw Exception(ErrorCodes::DATABASE_ACCESS_DENIED, "File `{}` is not inside `{}`", path, user_files_absolute_path);
}
query_interrupt_handler.start();
SCOPE_EXIT({ query_interrupt_handler.stop(); });
@ -1641,7 +1680,10 @@ void ClientBase::processInsertQuery(const String & query_to_execute, ASTPtr pars
{
/// If structure was received (thus, server has not thrown an exception),
/// send our data with that structure.
if (global_context->getApplicationType() != Context::ApplicationType::SERVER)
{
setInsertionTable(parsed_insert_query);
}
sendData(sample, columns_description, parsed_query);
receiveEndOfQuery();
@ -2122,6 +2164,8 @@ void ClientBase::processParsedSingleQuery(const String & full_query, const Strin
{
const String & new_database = use_query->getDatabase();
/// If the client initiates the reconnection, it takes the settings from the config.
/// TODO: Revisit
default_database = new_database;
getClientConfiguration().setString("database", new_database);
/// If the connection initiates the reconnection, it uses its variable.
connection->setDefaultDatabase(new_database);
@ -2459,13 +2503,13 @@ bool ClientBase::executeMultiQuery(const String & all_queries_text)
if (!server_exception)
{
error_matches_hint = false;
fmt::print(stderr, "Expected server error code '{}' but got no server error (query: {}).\n",
error_stream << fmt::format("Expected server error code '{}' but got no server error (query: {}).\n",
test_hint.serverErrors(), full_query);
}
else if (!test_hint.hasExpectedServerError(server_exception->code()))
{
error_matches_hint = false;
fmt::print(stderr, "Expected server error code: {} but got: {} (query: {}).\n",
error_stream << fmt::format("Expected server error code: {} but got: {} (query: {}).\n",
test_hint.serverErrors(), server_exception->code(), full_query);
}
}
@ -2474,13 +2518,13 @@ bool ClientBase::executeMultiQuery(const String & all_queries_text)
if (!client_exception)
{
error_matches_hint = false;
fmt::print(stderr, "Expected client error code '{}' but got no client error (query: {}).\n",
error_stream << fmt::format("Expected client error code '{}' but got no client error (query: {}).\n",
test_hint.clientErrors(), full_query);
}
else if (!test_hint.hasExpectedClientError(client_exception->code()))
{
error_matches_hint = false;
fmt::print(stderr, "Expected client error code '{}' but got '{}' (query: {}).\n",
error_stream << fmt::format("Expected client error code '{}' but got '{}' (query: {}).\n",
test_hint.clientErrors(), client_exception->code(), full_query);
}
}
@ -2497,14 +2541,14 @@ bool ClientBase::executeMultiQuery(const String & all_queries_text)
if (test_hint.hasClientErrors())
{
error_matches_hint = false;
fmt::print(stderr,
error_stream << fmt::format(
"The query succeeded but the client error '{}' was expected (query: {}).\n",
test_hint.clientErrors(), full_query);
}
if (test_hint.hasServerErrors())
{
error_matches_hint = false;
fmt::print(stderr,
error_stream << fmt::format(
"The query succeeded but the server error '{}' was expected (query: {}).\n",
test_hint.serverErrors(), full_query);
}
@ -2737,33 +2781,54 @@ void ClientBase::runInteractive()
LineReader::Patterns query_delimiters = {";", "\\G", "\\G;"};
char word_break_characters[] = " \t\v\f\a\b\r\n`~!@#$%^&*()-=+[{]}\\|;:'\",<.>/?";
std::unique_ptr<LineReader> lr;
#if USE_REPLXX
replxx::Replxx::highlighter_callback_t highlight_callback{};
if (getClientConfiguration().getBool("highlight", true))
highlight_callback = [this](const String & query, std::vector<replxx::Replxx::Color> & colors)
{
highlight(query, colors, *client_context);
};
ReplxxLineReader lr(
String actual_history_file_path;
if (global_context->getApplicationType() != Context::ApplicationType::SERVER)
actual_history_file_path = history_file;
lr = std::make_unique<ReplxxLineReader>(
*suggest,
history_file,
actual_history_file_path,
history_max_entries,
getClientConfiguration().has("multiline"),
getClientConfiguration().getBool("ignore_shell_suspend", true),
query_extenders,
query_delimiters,
word_break_characters,
highlight_callback);
highlight_callback,
input_stream,
output_stream,
in_fd,
out_fd,
err_fd
);
#else
(void)word_break_characters;
LineReader lr(
lr = LineReader(
history_file,
getClientConfiguration().has("multiline"),
query_extenders,
query_delimiters);
query_delimiters,
word_break_characters,
input_stream,
output_stream,
in_fd
);
#endif
/// Enable bracketed-paste-mode so that we are able to paste multiline queries as a whole.
lr->enableBracketedPaste();
static const std::initializer_list<std::pair<String, String>> backslash_aliases =
{
{ "\\l", "SHOW DATABASES" },
@ -2787,10 +2852,10 @@ void ClientBase::runInteractive()
/// But keep it disabled outside of query input, because it breaks password input
/// (e.g. if we need to reconnect and show a password prompt).
/// (Alternatively, we could make the password input ignore the control sequences.)
lr.enableBracketedPaste();
SCOPE_EXIT({ lr.disableBracketedPaste(); });
lr->enableBracketedPaste();
SCOPE_EXIT({ lr->disableBracketedPaste(); });
input = lr.readLine(prompt(), ":-] ");
input = lr->readLine(prompt(), ":-] ");
}
if (input.empty())
@ -2929,7 +2994,7 @@ void ClientBase::runNonInteractive()
{
/// If 'query' parameter is not set, read a query from stdin.
/// The query is read entirely into memory (streaming is disabled).
ReadBufferFromFileDescriptor in(STDIN_FILENO);
ReadBufferFromFileDescriptor in(in_fd);
String text;
readStringUntilEOF(text, in);
if (query_fuzzer_runs)

View File

@ -70,6 +70,7 @@ enum ProgressOption
ProgressOption toProgressOption(std::string progress);
std::istream& operator>> (std::istream & in, ProgressOption & progress);
class InternalTextLogs;
class TerminalKeystrokeInterceptor;
class WriteBufferFromFileDescriptor;
@ -132,6 +133,8 @@ protected:
static void adjustQueryEnd(const char *& this_query_end, const char * all_queries_end, uint32_t max_parser_depth, uint32_t max_parser_backtracks);
virtual void setupSignalHandler() = 0;
ASTPtr parseQuery(const char *& pos, const char * end, bool allow_multi_statements) const;
bool executeMultiQuery(const String & all_queries_text);
MultiQueryProcessingStage analyzeMultiQueryText(
const char *& this_query_begin, const char *& this_query_end, const char * all_queries_end,
@ -163,13 +166,6 @@ protected:
/// Returns true if query processing was successful.
bool processQueryText(const String & text);
virtual void readArguments(
int argc,
char ** argv,
Arguments & common_arguments,
std::vector<Arguments> & external_tables_arguments,
std::vector<Arguments> & hosts_and_ports_arguments) = 0;
void setInsertionTable(const ASTInsertQuery & insert_query);
private:
@ -223,6 +219,7 @@ protected:
void start(Int32 signals_before_stop = 1) { exit_after_signals.store(signals_before_stop); }
/// Set value not greater then 0 to mark the query as stopped.
void stop() { exit_after_signals.store(0); }
/// Return true if the query was stopped.
@ -257,12 +254,19 @@ protected:
/// Should be one of the first, to be destroyed the last,
/// since other members can use them.
SharedContextHolder shared_context;
SharedContextHolder shared_context; // maybe not initialized
ContextMutablePtr global_context;
/// Client context is a context used only by the client to parse queries, process query parameters and to connect to clickhouse-server.
ContextMutablePtr client_context;
String default_database;
String query_id;
Int32 suggestion_limit;
bool enable_highlight = true;
bool multiline = false;
String static_query;
std::unique_ptr<TerminalKeystrokeInterceptor> keystroke_interceptor;
bool is_interactive = false; /// Use either interactive line editing interface or batch mode.
@ -307,7 +311,7 @@ protected:
MergeTreeSettings cmd_merge_tree_settings;
/// thread status should be destructed before shared context because it relies on process list.
std::optional<ThreadStatus> thread_status;
std::optional<ThreadStatus> thread_status; // may be not initialized in embedded client
ServerConnectionPtr connection;
ConnectionParameters connection_parameters;
@ -328,6 +332,7 @@ protected:
std::unique_ptr<InternalTextLogs> logs_out_stream;
/// /dev/tty if accessible or std::cerr - for progress bar.
/// But running embedded into server, we write the progress to given tty file dexcriptor.
/// We prefer to output progress bar directly to tty to allow user to redirect stdout and stderr and still get the progress indication.
std::unique_ptr<WriteBufferFromFileDescriptor> tty_buf;
std::mutex tty_mutex;

View File

@ -11,7 +11,7 @@
#include <IO/TimeoutSetter.h>
#include <Formats/NativeReader.h>
#include <Formats/NativeWriter.h>
#include <Client/ClientBase.h>
#include <Client/ClientApplicationBase.h>
#include <Client/Connection.h>
#include <Client/ConnectionParameters.h>
#include "Common/logger_useful.h"

View File

@ -12,7 +12,6 @@
#include <readpassphrase/readpassphrase.h>
namespace DB
{
@ -39,15 +38,32 @@ bool enableSecureConnection(const Poco::Util::AbstractConfiguration & config, co
}
ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfiguration & config,
std::string connection_host,
std::optional<UInt16> connection_port)
: host(connection_host)
, port(connection_port.value_or(getPortFromConfig(config, connection_host)))
ConnectionParameters ConnectionParameters::createForEmbedded(const String & user, const String & database)
{
security = enableSecureConnection(config, connection_host) ? Protocol::Secure::Enable : Protocol::Secure::Disable;
auto connection_params = ConnectionParameters();
connection_params.host = "localhost";
connection_params.security = Protocol::Secure::Disable;
connection_params.password = "";
connection_params.user = user;
connection_params.default_database = database;
connection_params.compression = Protocol::Compression::Disable;
default_database = config.getString("database", "");
// TODO: Pass settings struct.
// connection_params.timeouts = ConnectionTimeouts::getTCPTimeoutsWithFailover(getGlobal);
connection_params.timeouts.sync_request_timeout = Poco::Timespan(DBMS_DEFAULT_SYNC_REQUEST_TIMEOUT_SEC, 0);
return connection_params;
}
ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfiguration & config,
const Host & host_,
const Database & database,
std::optional<UInt16> port_)
: host(host_)
, port(port_.value_or(getPortFromConfig(config, host_)))
, default_database(database)
{
security = enableSecureConnection(config, host_) ? Protocol::Secure::Enable : Protocol::Secure::Disable;
/// changed the default value to "default" to fix the issue when the user in the prompt is blank
user = config.getString("user", "default");
@ -139,10 +155,10 @@ ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfigurati
Poco::Timespan(config.getInt("sync_request_timeout", DBMS_DEFAULT_SYNC_REQUEST_TIMEOUT_SEC), 0));
}
ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfiguration & config,
std::string connection_host)
: ConnectionParameters(config, config.getString("host", "localhost"), getPortFromConfig(config, connection_host))
ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfiguration & config_, const Host & host_, const Database & database_)
: ConnectionParameters(config_, host_, database_, getPortFromConfig(config_, host_))
{
}
UInt16 ConnectionParameters::getPortFromConfig(const Poco::Util::AbstractConfiguration & config,

View File

@ -1,5 +1,6 @@
#pragma once
#include <base/strong_typedef.h>
#include <Common/SSHWrapper.h>
#include <Core/Protocol.h>
#include <IO/ConnectionTimeouts.h>
@ -29,15 +30,20 @@ struct ConnectionParameters
Protocol::Compression compression = Protocol::Compression::Enable;
ConnectionTimeouts timeouts;
using Database = StrongTypedef<String, struct DatabaseTag>;
using Host = StrongTypedef<String, struct HostTag>;
ConnectionParameters() = default;
ConnectionParameters(const Poco::Util::AbstractConfiguration & config, String host);
ConnectionParameters(const Poco::Util::AbstractConfiguration & config, String host, std::optional<UInt16> port);
ConnectionParameters(const Poco::Util::AbstractConfiguration & config_, const Host & host_, const Database & database_);
ConnectionParameters(const Poco::Util::AbstractConfiguration & config_, const Host & host_, const Database & database_, std::optional<UInt16> port_);
static UInt16 getPortFromConfig(const Poco::Util::AbstractConfiguration & config, const std::string & connection_host);
/// Ask to enter the user's password if password option contains this value.
/// "\n" is used because there is hardly a chance that a user would use '\n' as password.
static constexpr std::string_view ASK_PASSWORD = "\n";
static ConnectionParameters createForEmbedded(const String & user, const String & database);
};
}

View File

@ -59,8 +59,11 @@ namespace DB
/// Allows delaying the start of query execution until the entirety of query is inserted.
bool LineReader::hasInputData() const
{
pollfd fd{in_fd, POLLIN, 0};
return poll(&fd, 1, 0) == 1;
timeval timeout = {0, 0};
fd_set fds{};
FD_ZERO(&fds);
FD_SET(in_fd, &fds);
return select(1, &fds, nullptr, nullptr, &timeout) == 1;
}
replxx::Replxx::completions_t LineReader::Suggest::getCompletions(const String & prefix, size_t prefix_length, const char * word_break_characters)
@ -131,7 +134,8 @@ void LineReader::Suggest::addWords(Words && new_words) // NOLINT(cppcoreguidelin
}
}
LineReader::LineReader(
LineReader::LineReader
(
const String & history_file_path_,
bool multiline_,
Patterns extenders_,

View File

@ -1,6 +1,7 @@
#include "LocalConnection.h"
#include <memory>
#include <Client/ClientBase.h>
#include <Client/ClientApplicationBase.h>
#include <Core/Protocol.h>
#include <Interpreters/DatabaseCatalog.h>
#include <Interpreters/executeQuery.h>
@ -46,15 +47,25 @@ namespace ErrorCodes
LocalConnection::LocalConnection(ContextPtr context_, ReadBuffer * in_, bool send_progress_, bool send_profile_events_, const String & server_display_name_)
: WithContext(context_)
, session(getContext(), ClientInfo::Interface::LOCAL)
, session(std::make_unique<Session>(getContext(), ClientInfo::Interface::LOCAL))
, send_progress(send_progress_)
, send_profile_events(send_profile_events_)
, server_display_name(server_display_name_)
, in(in_)
{
/// Authenticate and create a context to execute queries.
session.authenticate("default", "", Poco::Net::SocketAddress{});
session.makeSessionContext();
session->authenticate("default", "", Poco::Net::SocketAddress{});
session->makeSessionContext();
}
LocalConnection::LocalConnection(
std::unique_ptr<Session> && session_, bool send_progress_, bool send_profile_events_, const String & server_display_name_)
: WithContext(session_->sessionContext())
, session(std::move(session_))
, send_progress(send_progress_)
, send_profile_events(send_profile_events_)
, server_display_name(server_display_name_)
{
}
LocalConnection::~LocalConnection()
@ -115,9 +126,9 @@ void LocalConnection::sendQuery(
/// Suggestion comes without client_info.
if (client_info)
query_context = session.makeQueryContext(*client_info);
query_context = session->makeQueryContext(*client_info);
else
query_context = session.makeQueryContext();
query_context = session->makeQueryContext();
query_context->setCurrentQueryId(query_id);
if (send_progress)
@ -169,6 +180,7 @@ void LocalConnection::sendQuery(
const auto & settings = context->getSettingsRef();
const char * begin = state->query.data();
const char * end = begin + state->query.size();
const Dialect & dialect = settings[Setting::dialect];
@ -675,5 +687,15 @@ ServerConnectionPtr LocalConnection::createConnection(
return std::make_unique<LocalConnection>(current_context, in, send_progress, send_profile_events, server_display_name);
}
ServerConnectionPtr LocalConnection::createConnection(
const ConnectionParameters &,
std::unique_ptr<Session> && session,
bool send_progress,
bool send_profile_events,
const String & server_display_name)
{
return std::make_unique<LocalConnection>(std::move(session), send_progress, send_profile_events, server_display_name);
}
}

View File

@ -76,6 +76,12 @@ public:
bool send_profile_events_,
const String & server_display_name_);
explicit LocalConnection(
std::unique_ptr<Session> && session_,
bool send_progress_ = false,
bool send_profile_events_ = false,
const String & server_display_name_ = "");
~LocalConnection() override;
IServerConnection::Type getConnectionType() const override { return IServerConnection::Type::LOCAL; }
@ -88,6 +94,13 @@ public:
bool send_profile_events = false,
const String & server_display_name = "");
static ServerConnectionPtr createConnection(
const ConnectionParameters & connection_parameters,
std::unique_ptr<Session> && session,
bool send_progress = false,
bool send_profile_events = false,
const String & server_display_name = "");
void setDefaultDatabase(const String & database) override;
void getServerVersion(const ConnectionTimeouts & timeouts,
@ -160,7 +173,7 @@ private:
bool needSendProgressOrMetrics();
ContextMutablePtr query_context;
Session session;
std::unique_ptr<Session> session;
bool send_progress;
bool send_profile_events;

View File

@ -36,6 +36,11 @@ Poco::AutoPtr<Poco::Util::AbstractConfiguration> clone(const Poco::Util::Abstrac
return res;
}
Poco::AutoPtr<Poco::Util::AbstractConfiguration> createEmpty()
{
return new Poco::Util::XMLConfiguration();
}
bool getBool(const Poco::Util::AbstractConfiguration & config, const std::string & key, bool default_, bool empty_as)
{
if (!config.has(key))

View File

@ -18,6 +18,8 @@ namespace DB::ConfigHelper
/// (i.e. items like "<test>value<child1/></test>").
Poco::AutoPtr<Poco::Util::AbstractConfiguration> clone(const Poco::Util::AbstractConfiguration & src);
Poco::AutoPtr<Poco::Util::AbstractConfiguration> createEmpty();
/// The behavior is like `config.getBool(key, default_)`,
/// except when the tag is empty (aka. self-closing), `empty_as` will be used instead of throwing Poco::Exception.
bool getBool(const Poco::Util::AbstractConfiguration & config, const std::string & key, bool default_ = false, bool empty_as = true);

View File

@ -622,6 +622,7 @@
M(1000, POCO_EXCEPTION) \
M(1001, STD_EXCEPTION) \
M(1002, UNKNOWN_EXCEPTION) \
M(1003, SSH_EXCEPTION) \
/* See END */
#ifdef APPLY_FOR_EXTERNAL_ERROR_CODES
@ -638,7 +639,7 @@ namespace ErrorCodes
APPLY_FOR_ERROR_CODES(M)
#undef M
constexpr ErrorCode END = 1002;
constexpr ErrorCode END = 1003;
ErrorPairHolder values[END + 1]{};
struct ErrorCodesNames

View File

@ -0,0 +1,61 @@
#include "config.h"
#include <Common/LibSSHInitializer.h>
#include <Common/Exception.h>
#if USE_SSH
#include <Common/clibssh.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SSH_EXCEPTION;
}
}
namespace ssh
{
LibSSHInitializer & LibSSHInitializer::instance()
{
static LibSSHInitializer instance;
return instance;
}
LibSSHInitializer::LibSSHInitializer()
{
int rc = ssh_init();
if (rc != SSH_OK)
{
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to initialize libssh");
}
}
LibSSHInitializer::~LibSSHInitializer()
{
ssh_finalize();
}
}
#else
namespace ssh
{
LibSSHInitializer & LibSSHInitializer::instance()
{
static LibSSHInitializer instance;
return instance;
}
LibSSHInitializer::LibSSHInitializer() {}
LibSSHInitializer::~LibSSHInitializer() {}
}
#endif

View File

@ -0,0 +1,20 @@
#pragma once
namespace ssh
{
class LibSSHInitializer
{
public:
LibSSHInitializer(const LibSSHInitializer &) = delete;
LibSSHInitializer & operator=(const LibSSHInitializer &) = delete;
static LibSSHInitializer & instance();
~LibSSHInitializer();
private:
LibSSHInitializer(); // NOLINT
};
}

View File

@ -0,0 +1,62 @@
#include "config.h"
#if USE_SSH
# include <Common/logger_useful.h>
# include <Common/LibSSHLogger.h>
# include <Common/clibssh.h>
namespace ssh
{
namespace
{
void libssh_logger_callback(int priority, const char *, const char * buffer, void *)
{
Poco::Logger * logger = &Poco::Logger::get("LibSSH");
switch (priority)
{
case SSH_LOG_NOLOG:
break;
case SSH_LOG_WARNING:
LOG_WARNING(logger, "{}", buffer);
break;
case SSH_LOG_PROTOCOL:
case SSH_LOG_PACKET:
case SSH_LOG_FUNCTIONS:
LOG_TRACE(logger, "{}", buffer);
break;
}
}
}
namespace libsshLogger
{
void initialize()
{
ssh_set_log_callback(libssh_logger_callback);
ssh_set_log_level(SSH_LOG_FUNCTIONS); // Set the maximum log level
}
}
}
#else
namespace ssh
{
namespace libsshLogger
{
void initialize() {}
}
}
#endif

View File

@ -0,0 +1,8 @@
#pragma once
namespace ssh::libsshLogger
{
void initialize();
}

View File

@ -9,6 +9,7 @@
#include <iostream>
#include <mutex>
#include <queue>
#include <unistd.h>
#include <unordered_map>
#include <unordered_set>

View File

@ -37,9 +37,12 @@ public:
String getBase64() const;
String getKeyType() const;
friend class SSHKeyFactory;
private:
// private:
explicit SSHKey(ssh_key key_) : key(key_) { }
private:
ssh_key key = nullptr;
};

13
src/Common/clibssh.h Normal file
View File

@ -0,0 +1,13 @@
#pragma once
/*
Include this file to access libssh api.
*/
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wreserved-macro-identifier"
#pragma GCC diagnostic ignored "-Wreserved-identifier"
#pragma GCC diagnostic ignored "-Wdocumentation"
#include <libssh/libssh.h> // IWYU pragma: export
#include <libssh/server.h> // IWYU pragma: export
#include "libssh/callbacks.h" // IWYU pragma: export
#pragma GCC diagnostic pop

View File

@ -1254,6 +1254,7 @@ public:
{
SERVER, /// The program is run as clickhouse-server daemon (default behavior)
CLIENT, /// clickhouse-client
EMBEDDED_CLIENT,/// clickhouse-client being run over SSH tunnel
LOCAL, /// clickhouse-local
KEEPER, /// clickhouse-keeper (also daemon)
DISKS, /// clickhouse-disks

View File

@ -0,0 +1,193 @@
#include <Server/ClientEmbedded/ClientEmbedded.h>
#include <base/getFQDNOrHostName.h>
#include <Interpreters/Session.h>
#include <boost/algorithm/string/replace.hpp>
#include "Common/setThreadName.h"
#include <Common/Exception.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace Setting
{
extern const SettingsUInt64 max_insert_block_size;
}
namespace
{
template<typename T>
T getEnvOption(const NameToNameMap & envVars, const String & key, T defaultValue)
{
auto it = envVars.find(key);
return it == envVars.end() ? defaultValue : parse<T>(it->second);
}
}
void ClientEmbedded::processError(const String &) const
{
if (ignore_error)
return;
if (is_interactive)
{
String message;
if (server_exception)
{
message = getExceptionMessage(*server_exception, print_stack_trace, true);
}
else if (client_exception)
{
message = client_exception->message();
}
error_stream << fmt::format("Received exception\n{}\n\n", message);
}
else
{
if (server_exception)
server_exception->rethrow();
if (client_exception)
client_exception->rethrow();
}
}
void ClientEmbedded::cleanup()
{
try
{
connection.reset();
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
void ClientEmbedded::connect()
{
if (!session)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Error creating connection without session object");
}
connection_parameters = ConnectionParameters::createForEmbedded(session->sessionContext()->getUserName(), default_database);
connection = LocalConnection::createConnection(
connection_parameters, std::move(session), need_render_progress, need_render_profile_events, server_display_name);
if (!default_database.empty())
{
connection->setDefaultDatabase(default_database);
}
}
Poco::Util::LayeredConfiguration & ClientEmbedded::getClientConfiguration()
{
chassert(layered_configuration);
return *layered_configuration;
}
int ClientEmbedded::run(const NameToNameMap & envVars, const String & first_query)
{
try
{
client_context = session->sessionContext();
initClientContext();
setThreadName("LocalServerPty");
query_processing_stage = QueryProcessingStage::Enum::Complete;
print_stack_trace = getEnvOption<bool>(envVars, "stacktrace", false);
output_stream << std::fixed << std::setprecision(3);
error_stream << std::fixed << std::setprecision(3);
is_interactive = stdin_is_a_tty;
static_query = first_query.empty() ? getEnvOption<String>(envVars, "query", "") : first_query;
delayed_interactive = is_interactive && !static_query.empty();
if (!is_interactive || delayed_interactive)
{
echo_queries = getEnvOption<bool>(envVars, "echo", false) || getEnvOption<bool>(envVars, "verbose", false);
ignore_error = getEnvOption<bool>(envVars, "ignore_error", false);
}
load_suggestions = (is_interactive || delayed_interactive) && !getEnvOption<bool>(envVars, "disable_suggestion", false);
if (load_suggestions)
{
suggestion_limit = getEnvOption<Int32>(envVars, "suggestion_limit", 10000);
}
enable_highlight = getEnvOption<bool>(envVars, "highlight", true);
multiline = getEnvOption<bool>(envVars, "multiline", false);
default_database = getEnvOption<String>(envVars, "database", "");
default_output_format = getEnvOption<String>(envVars, "output-format", getEnvOption<String>(envVars, "format", is_interactive ? "PrettyCompact" : "TSV"));
// TODO: Fix
// insert_format = "Values";
insert_format_max_block_size = getEnvOption<size_t>(envVars, "insert_format_max_block_size",
global_context->getSettingsRef()[Setting::max_insert_block_size]);
server_display_name = getEnvOption<String>(envVars, "display_name", getFQDNOrHostName());
prompt_by_server_display_name = getEnvOption<String>(envVars, "prompt_by_server_display_name", "{display_name} :) ");
std::map<String, String> prompt_substitutions{{"display_name", server_display_name}};
for (const auto & [key, value] : prompt_substitutions)
boost::replace_all(prompt_by_server_display_name, "{" + key + "}", value);
initTTYBuffer(toProgressOption(getEnvOption<String>(envVars, "progress", "default")),
toProgressOption(getEnvOption<String>(envVars, "progress-table", "default")));
if (is_interactive)
{
clearTerminal();
showClientVersion();
error_stream << std::endl;
}
connect();
if (is_interactive && !delayed_interactive)
{
runInteractive();
}
else
{
runNonInteractive();
if (delayed_interactive)
runInteractive();
}
cleanup();
return 0;
}
catch (const DB::Exception & e)
{
cleanup();
error_stream << getExceptionMessage(e, print_stack_trace, true) << std::endl;
return e.code() ? e.code() : -1;
}
catch (...)
{
cleanup();
error_stream << getCurrentExceptionMessage(false) << std::endl;
return getCurrentExceptionCode();
}
}
}

View File

@ -0,0 +1,77 @@
#pragma once
#include <Client/ClientBase.h>
#include <Client/LocalConnection.h>
#include <Core/Settings.h>
#include <Interpreters/Context.h>
#include <Loggers/Loggers.h>
#include <Common/InterruptListener.h>
#include <Common/StatusFile.h>
#include <Common/Config/ConfigHelper.h>
#include <Common/Config/ConfigProcessor.h>
#include <Poco/Util/LayeredConfiguration.h>
#include <filesystem>
#include <memory>
#include <optional>
namespace DB
{
// Client class which can be run embedded into server
class ClientEmbedded : public ClientBase
{
public:
explicit ClientEmbedded(
std::unique_ptr<Session> && session_,
int in_fd_,
int out_fd_,
int err_fd_,
std::istream & input_stream_,
std::ostream & output_stream_,
std::ostream & error_stream_)
: ClientBase(in_fd_, out_fd_, err_fd_, input_stream_, output_stream_, error_stream_), session(std::move(session_))
{
global_context = session->makeSessionContext();
configuration = ConfigHelper::createEmpty();
layered_configuration = new Poco::Util::LayeredConfiguration();
layered_configuration->add(configuration);
}
int run(const NameToNameMap & envVars, const String & first_query);
/// NOP
void setupSignalHandler() override {}
~ClientEmbedded() override { cleanup(); }
protected:
void connect() override;
Poco::Util::LayeredConfiguration & getClientConfiguration() override;
void processError(const String & query) const override;
String getName() const override { return "embedded"; }
void printHelpMessage(const OptionsDescription &, bool) override {}
void addOptions(OptionsDescription &) override {}
void processOptions(const OptionsDescription &,
const CommandLineOptions &,
const std::vector<Arguments> &,
const std::vector<Arguments> &) override {}
void processConfig() override {}
private:
void cleanup();
std::unique_ptr<Session> session;
ConfigurationPtr configuration;
Poco::AutoPtr<Poco::Util::LayeredConfiguration> layered_configuration;
};
}

View File

@ -0,0 +1,64 @@
#include <Server/ClientEmbedded/ClientEmbeddedRunner.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
void ClientEmbeddedRunner::run(const NameToNameMap & envs, const String & starting_query)
{
if (started.test_and_set())
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Client has been already started");
}
LOG_DEBUG(log, "Starting client");
client_thread = ThreadFromGlobalPool(&ClientEmbeddedRunner::clientRoutine, this, envs, starting_query);
}
void ClientEmbeddedRunner::changeWindowSize(int width, int height, int width_pixels, int height_pixels)
{
auto * pty_descriptors = dynamic_cast<PtyClientDescriptorSet *>(client_descriptors.get());
if (pty_descriptors == nullptr)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Accessing window change on non pty descriptors");
}
pty_descriptors->changeWindowSize(width, height, width_pixels, height_pixels);
}
ClientEmbeddedRunner::~ClientEmbeddedRunner()
{
LOG_DEBUG(log, "Closing server descriptors and waiting for client to finish");
client_descriptors->closeServerDescriptors(); // May throw if something bad happens to descriptors, which will call std::terminate
if (client_thread.joinable())
{
client_thread.join();
}
LOG_DEBUG(log, "Client has finished");
}
void ClientEmbeddedRunner::clientRoutine(NameToNameMap envs, String starting_query)
{
try
{
auto descr = client_descriptors->getDescriptorsForClient();
auto stre = client_descriptors->getStreamsForClient();
ClientEmbedded client(std::move(db_session), descr.in, descr.out, descr.err, stre.in, stre.out, stre.err);
client.run(envs, starting_query);
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
finished.test_and_set();
char c = 0;
// Server may poll on a descriptor waiting for client output, wake him up with invisible character
write(client_descriptors->getDescriptorsForClient().out, &c, 1);
}
}

View File

@ -0,0 +1,52 @@
#pragma once
#include <stdexcept>
#include <Server/ClientEmbedded/ClientEmbedded.h>
#include <Server/ClientEmbedded/IClientDescriptorSet.h>
#include <Server/ClientEmbedded/PipeClientDescriptorSet.h>
#include <Server/ClientEmbedded/PtyClientDescriptorSet.h>
#include <Common/ThreadPool.h>
namespace DB
{
// Runs embedded client in dedicated thread, passes descriptors, checks its state
class ClientEmbeddedRunner
{
public:
bool hasStarted() { return started.test(); }
bool hasFinished() { return finished.test(); }
// void stopQuery() { client.stopQuery(); } // this is save for client until he uses thread-safe structures to handle query stopping
void run(const NameToNameMap & envs, const String & starting_query = "");
IClientDescriptorSet::DescriptorSet getDescriptorsForServer() { return client_descriptors->getDescriptorsForServer(); }
bool hasPty() const { return client_descriptors->isPty(); }
// Sets new window size for tty. Works only if IClientDescriptorSet is pty
void changeWindowSize(int width, int height, int width_pixels, int height_pixels);
~ClientEmbeddedRunner();
explicit ClientEmbeddedRunner(std::unique_ptr<IClientDescriptorSet> && client_descriptor_, std::unique_ptr<Session> && dbSession_)
: client_descriptors(std::move(client_descriptor_)), db_session(std::move(dbSession_)), log(&Poco::Logger::get("ClientEmbeddedRunner"))
{
}
private:
void clientRoutine(NameToNameMap envs, String starting_query);
std::unique_ptr<IClientDescriptorSet>
client_descriptors; // This is used by server thread and client thread, be sure that server only gets them via getDescriptorsForServer
std::atomic_flag started = ATOMIC_FLAG_INIT;
std::atomic_flag finished = ATOMIC_FLAG_INIT;
ThreadFromGlobalPool client_thread;
std::unique_ptr<Session> db_session;
Poco::Logger * log;
};
}

View File

@ -0,0 +1,39 @@
#pragma once
#include <iostream>
namespace DB
{
// Interface, which handles descriptor pairs, that could be attached to embedded client
class IClientDescriptorSet
{
public:
struct DescriptorSet
{
int in = -1;
int out = -1;
int err = -1;
};
struct StreamSet
{
std::istream & in;
std::ostream & out;
std::ostream & err;
};
virtual DescriptorSet getDescriptorsForClient() = 0;
virtual DescriptorSet getDescriptorsForServer() = 0;
virtual StreamSet getStreamsForClient() = 0;
virtual bool isPty() const = 0;
virtual void closeServerDescriptors() = 0;
virtual ~IClientDescriptorSet() = default;
};
}

View File

@ -0,0 +1,64 @@
#pragma once
#include <Server/ClientEmbedded/IClientDescriptorSet.h>
#include <boost/iostreams/device/file_descriptor.hpp>
#include <boost/iostreams/stream.hpp>
#include <Poco/Pipe.h>
namespace DB
{
class PipeClientDescriptorSet : public IClientDescriptorSet
{
public:
PipeClientDescriptorSet()
: fd_source(pipe_in.readHandle(), boost::iostreams::never_close_handle)
, fd_sink(pipe_out.writeHandle(), boost::iostreams::never_close_handle)
, fd_sink_err(pipe_err.writeHandle(), boost::iostreams::never_close_handle)
, input_stream(fd_source)
, output_stream(fd_sink)
, output_stream_err(fd_sink_err)
{
output_stream << std::unitbuf;
output_stream_err << std::unitbuf;
}
DescriptorSet getDescriptorsForClient() override
{
return DescriptorSet{.in = pipe_in.readHandle(), .out = pipe_out.writeHandle(), .err = pipe_err.writeHandle()};
}
DescriptorSet getDescriptorsForServer() override
{
return DescriptorSet{.in = pipe_in.writeHandle(), .out = pipe_out.readHandle(), .err = pipe_err.readHandle()};
}
StreamSet getStreamsForClient() override { return StreamSet{.in = input_stream, .out = output_stream, .err = output_stream_err}; }
void closeServerDescriptors() override
{
pipe_in.close(Poco::Pipe::CLOSE_WRITE);
pipe_out.close(Poco::Pipe::CLOSE_READ);
pipe_err.close(Poco::Pipe::CLOSE_READ);
}
bool isPty() const override { return false; }
~PipeClientDescriptorSet() override = default;
private:
Poco::Pipe pipe_in;
Poco::Pipe pipe_out;
Poco::Pipe pipe_err;
// Provide streams on top of file descriptors
boost::iostreams::file_descriptor_source fd_source;
boost::iostreams::file_descriptor_sink fd_sink;
boost::iostreams::file_descriptor_sink fd_sink_err;
boost::iostreams::stream<boost::iostreams::file_descriptor_source> input_stream;
boost::iostreams::stream<boost::iostreams::file_descriptor_sink> output_stream;
boost::iostreams::stream<boost::iostreams::file_descriptor_sink> output_stream_err;
};
}

View File

@ -0,0 +1,77 @@
#include <Server/ClientEmbedded/PtyClientDescriptorSet.h>
#include <Common/Exception.h>
#include "pty.h"
namespace DB
{
namespace ErrorCodes
{
extern const int SYSTEM_ERROR;
}
void PtyClientDescriptorSet::FileDescriptorWrapper::close()
{
if (fd != -1)
{
if (::close(fd) != 0 && errno != EINTR)
throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Unexpected error while closing file descriptor");
}
fd = -1;
}
PtyClientDescriptorSet::PtyClientDescriptorSet(const String & term_name_, int width, int height, int width_pixels, int height_pixels)
: term_name(term_name_)
{
winsize winsize{};
winsize.ws_col = width;
winsize.ws_row = height;
winsize.ws_xpixel = width_pixels;
winsize.ws_ypixel = height_pixels;
int pty_master_raw = -1, pty_slave_raw = -1;
if (openpty(&pty_master_raw, &pty_slave_raw, nullptr, nullptr, &winsize) != 0)
{
throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Cannot open pty");
}
pty_master.capture(pty_master_raw);
pty_slave.capture(pty_slave_raw);
fd_source.open(pty_slave.get(), boost::iostreams::never_close_handle);
fd_sink.open(pty_slave.get(), boost::iostreams::never_close_handle);
// disable signals from tty
struct termios tios;
if (tcgetattr(pty_slave.get(), &tios) == -1)
{
throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Cannot get termios from tty via tcgetattr");
}
tios.c_lflag &= ~ISIG;
if (tcsetattr(pty_slave.get(), TCSANOW, &tios) == -1)
{
throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Cannot set termios to tty via tcsetattr");
}
input_stream.open(fd_source);
output_stream.open(fd_sink);
output_stream << std::unitbuf;
}
void PtyClientDescriptorSet::changeWindowSize(int width, int height, int width_pixels, int height_pixels) const
{
winsize winsize{};
winsize.ws_col = width;
winsize.ws_row = height;
winsize.ws_xpixel = width_pixels;
winsize.ws_ypixel = height_pixels;
if (ioctl(pty_master.get(), TIOCSWINSZ, &winsize) == -1)
{
throw ErrnoException(ErrorCodes::SYSTEM_ERROR, "Cannot update terminal window size via ioctl TIOCSWINSZ");
}
}
PtyClientDescriptorSet::~PtyClientDescriptorSet() = default;
}

View File

@ -0,0 +1,68 @@
#pragma once
#include <Server/ClientEmbedded/IClientDescriptorSet.h>
#include <boost/iostreams/device/file_descriptor.hpp>
#include <boost/iostreams/stream.hpp>
#include <Poco/Pipe.h>
#include "base/types.h"
namespace DB
{
class PtyClientDescriptorSet : public IClientDescriptorSet
{
public:
PtyClientDescriptorSet(const String & term_name, int width, int height, int width_pixels, int height_pixels);
DescriptorSet getDescriptorsForClient() override
{
return DescriptorSet{.in = pty_slave.get(), .out = pty_slave.get(), .err = pty_slave.get()};
}
DescriptorSet getDescriptorsForServer() override { return DescriptorSet{.in = pty_master.get(), .out = pty_master.get(), .err = -1}; }
StreamSet getStreamsForClient() override { return StreamSet{.in = input_stream, .out = output_stream, .err = output_stream}; }
void changeWindowSize(int width, int height, int width_pixels, int height_pixels) const;
void closeServerDescriptors() override { pty_master.close(); }
bool isPty() const override { return true; }
~PtyClientDescriptorSet() override;
private:
class FileDescriptorWrapper
{
public:
FileDescriptorWrapper() = default;
void capture(int fd_)
{
close();
fd = fd_;
}
int get() const { return fd; }
void close();
~FileDescriptorWrapper() { close(); } // may throw, thus std::terminate
private:
int fd = -1;
};
String term_name;
FileDescriptorWrapper pty_master;
FileDescriptorWrapper pty_slave;
// Provide streams on top of file descriptors
boost::iostreams::file_descriptor_source fd_source; // handles pty_slave lifetime
boost::iostreams::file_descriptor_sink fd_sink;
boost::iostreams::stream<boost::iostreams::file_descriptor_source> input_stream;
boost::iostreams::stream<boost::iostreams::file_descriptor_sink> output_stream;
};
}

View File

@ -0,0 +1,79 @@
#include <Server/SSH/SSHBind.h>
#if USE_SSH
#include <stdexcept>
#include <fmt/format.h>
#include <Common/Exception.h>
#include <Common/clibssh.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SSH_EXCEPTION;
}
}
namespace ssh
{
SSHBind::SSHBind() : bind(ssh_bind_new(), &deleter)
{
if (!bind)
{
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to create ssh_bind");
}
}
SSHBind::~SSHBind() = default;
SSHBind::SSHBind(SSHBind && other) noexcept = default;
SSHBind & SSHBind::operator=(SSHBind && other) noexcept = default;
void SSHBind::setHostKey(const std::string & key_path)
{
if (ssh_bind_options_set(bind.get(), SSH_BIND_OPTIONS_HOSTKEY, key_path.c_str()) != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting host key in sshbind due to {}", getError());
}
String SSHBind::getError()
{
return String(ssh_get_error(bind.get()));
}
void SSHBind::disableDefaultConfig()
{
bool enable = false;
if (ssh_bind_options_set(bind.get(), SSH_BIND_OPTIONS_PROCESS_CONFIG, &enable) != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed disabling default config in sshbind due to {}", getError());
}
void SSHBind::setFd(int fd)
{
ssh_bind_set_fd(bind.get(), fd);
}
void SSHBind::listen()
{
if (ssh_bind_listen(bind.get()) != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed listening in sshbind due to {}", getError());
}
void SSHBind::acceptFd(SSHSession & session, int fd)
{
if (ssh_bind_accept_fd(bind.get(), session.getInternalPtr(), fd) != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed accepting fd in sshbind due to {}", getError());
}
void SSHBind::deleter(BindPtr bind)
{
ssh_bind_free(bind);
}
}
#endif

54
src/Server/SSH/SSHBind.h Normal file
View File

@ -0,0 +1,54 @@
#pragma once
#include "config.h"
#if USE_SSH
#include <cstdint>
#include <memory>
#include <string>
#include <base/types.h>
#include <Server/SSH/SSHSession.h>
struct ssh_bind_struct;
namespace ssh
{
// Wrapper around libssh's ssh_bind
class SSHBind
{
public:
using BindPtr = ssh_bind_struct *;
SSHBind();
~SSHBind();
SSHBind(const SSHBind &) = delete;
SSHBind & operator=(const SSHBind &) = delete;
SSHBind(SSHBind &&) noexcept;
SSHBind & operator=(SSHBind &&) noexcept;
// Disables libssh's default config
void disableDefaultConfig();
// Sets host key for a server. It can be set only one time for each key type.
// If you provide different keys of one type, the first one will be overwritten.
void setHostKey(const std::string & key_path);
// Passes external socket to ssh_bind
void setFd(int fd);
// Listens on a socket. If it was passed via setFd just read hostkeys
void listen();
// Assign accepted socket to ssh_session
void acceptFd(SSHSession & session, int fd);
String getError();
private:
static void deleter(BindPtr bind);
std::unique_ptr<ssh_bind_struct, decltype(&deleter)> bind;
};
}
#endif

View File

@ -0,0 +1,87 @@
#include <Server/SSH/SSHChannel.h>
#if USE_SSH
#include <stdexcept>
#include <Common/Exception.h>
#include <Common/clibssh.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SSH_EXCEPTION;
}
}
namespace ssh
{
SSHChannel::SSHChannel(SSHSession::SessionPtr session) : channel(ssh_channel_new(session), &deleter)
{
if (!channel)
{
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to create ssh_channel");
}
}
SSHChannel::~SSHChannel() = default;
SSHChannel::SSHChannel(SSHChannel && other) noexcept : channel(std::move(other.channel))
{
}
SSHChannel & SSHChannel::operator=(SSHChannel && other) noexcept
{
if (this != &other)
{
channel = std::move(other.channel);
}
return *this;
}
ssh_channel SSHChannel::getCChannelPtr() const
{
return channel.get();
}
int SSHChannel::read(void * dest, uint32_t count, int isStderr)
{
return ssh_channel_read(channel.get(), dest, count, isStderr);
}
int SSHChannel::readTimeout(void * dest, uint32_t count, int isStderr, int timeout)
{
return ssh_channel_read_timeout(channel.get(), dest, count, isStderr, timeout);
}
int SSHChannel::write(const void * data, uint32_t len)
{
return ssh_channel_write(channel.get(), data, len);
}
int SSHChannel::sendEof()
{
return ssh_channel_send_eof(channel.get());
}
int SSHChannel::close()
{
return ssh_channel_close(channel.get());
}
bool SSHChannel::isOpen()
{
return ssh_channel_is_open(channel.get()) != 0;
}
void SSHChannel::deleter(ssh_channel ch)
{
ssh_channel_free(ch);
}
}
#endif

View File

@ -0,0 +1,50 @@
#pragma once
#include "config.h"
#if USE_SSH
#include <memory>
#include <Server/SSH/SSHSession.h>
struct ssh_channel_struct;
namespace ssh
{
// Wrapper around libssh's ssh_channel
class SSHChannel
{
public:
using ChannelPtr = ssh_channel_struct *;
explicit SSHChannel(SSHSession::SessionPtr session);
~SSHChannel();
SSHChannel(const SSHChannel &) = delete;
SSHChannel & operator=(const SSHChannel &) = delete;
SSHChannel(SSHChannel &&) noexcept;
SSHChannel & operator=(SSHChannel &&) noexcept;
// Exposes ssh_channel c pointer, which could be used to be passed into other objects
ChannelPtr getCChannelPtr() const;
int read(void * dest, uint32_t count, int isStderr);
int readTimeout(void * dest, uint32_t count, int isStderr, int timeout);
int write(const void * data, uint32_t len);
// Send eof signal to the other side of channel. It does not close the socket.
int sendEof();
// Sends eof if it has not been sent and then closes channel.
int close();
bool isOpen();
private:
static void deleter(ChannelPtr ch);
std::unique_ptr<ssh_channel_struct, decltype(&deleter)> channel;
};
}
#endif

View File

@ -0,0 +1,93 @@
#include <Server/SSH/SSHEvent.h>
#if USE_SSH
#include <stdexcept>
#include <Common/Exception.h>
#include <Common/clibssh.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SSH_EXCEPTION;
}
}
namespace ssh
{
SSHEvent::SSHEvent() : event(ssh_event_new(), &deleter)
{
if (!event)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to create ssh_event");
}
SSHEvent::~SSHEvent() = default;
SSHEvent::SSHEvent(SSHEvent && other) noexcept : event(std::move(other.event))
{
}
SSHEvent & SSHEvent::operator=(SSHEvent && other) noexcept
{
if (this != &other)
{
event = std::move(other.event);
}
return *this;
}
void SSHEvent::addSession(SSHSession & session)
{
if (ssh_event_add_session(event.get(), session.getInternalPtr()) != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Error adding session to ssh event");
}
void SSHEvent::removeSession(SSHSession & session)
{
if (ssh_event_remove_session(event.get(), session.getInternalPtr()) != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Error removing session from ssh event");
}
int SSHEvent::poll(int timeout)
{
while (true)
{
int rc = ssh_event_dopoll(event.get(), timeout);
if (rc == SSH_AGAIN)
continue;
if (rc != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Error on polling on ssh event: {}", rc);
return rc;
}
}
int SSHEvent::poll()
{
return poll(-1);
}
void SSHEvent::addFd(int fd, int events, EventCallback cb, void * userdata)
{
if (ssh_event_add_fd(event.get(), fd, events, cb, userdata) != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Error on adding custom file descriptor to ssh event");
}
void SSHEvent::removeFd(socket_t fd)
{
ssh_event_remove_fd(event.get(), fd);
}
void SSHEvent::deleter(EventPtr e)
{
ssh_event_free(e);
}
}
#endif

48
src/Server/SSH/SSHEvent.h Normal file
View File

@ -0,0 +1,48 @@
#pragma once
#include "config.h"
#if USE_SSH
#include <memory>
#include <Server/SSH/SSHSession.h>
struct ssh_event_struct;
namespace ssh
{
// Wrapper around ssh_event from libssh
class SSHEvent
{
public:
using EventPtr = ssh_event_struct *;
using EventCallback = int (*)(int fd, int revents, void * userdata);
SSHEvent();
~SSHEvent();
SSHEvent(const SSHEvent &) = delete;
SSHEvent & operator=(const SSHEvent &) = delete;
SSHEvent(SSHEvent &&) noexcept;
SSHEvent & operator=(SSHEvent &&) noexcept;
// Adds session's socket to event. Now callbacks will be executed for this session on poll
void addSession(SSHSession & session);
void removeSession(SSHSession & session);
// Add fd to ssh_event and assign callbacks on fd event
void addFd(int fd, int events, EventCallback cb, void * userdata);
void removeFd(int fd);
int poll(int timeout);
int poll();
private:
static void deleter(EventPtr e);
std::unique_ptr<ssh_event_struct, decltype(&deleter)> event;
};
}
#endif

View File

@ -0,0 +1,519 @@
#include <Server/SSH/SSHPtyHandler.h>
#if USE_SSH
#include <Access/Common/AuthenticationType.h>
#include <Access/Credentials.h>
#include <Access/SSH/SSHPublicKey.h>
#include <Common/clibssh.h>
#include <Common/logger_useful.h>
#include <Core/Names.h>
#include <Poco/Net/StreamSocket.h>
#include <Poco/Pipe.h>
#include <Server/ClientEmbedded/ClientEmbeddedRunner.h>
#include <Server/ClientEmbedded/IClientDescriptorSet.h>
#include <Server/ClientEmbedded/PtyClientDescriptorSet.h>
#include <Server/SSH/SSHChannel.h>
#include <Server/SSH/SSHEvent.h>
#include <sys/poll.h>
#include <atomic>
#include <stdexcept>
#include <boost/iostreams/device/file_descriptor.hpp>
#include <boost/iostreams/stream.hpp>
namespace
{
/*
Need to generate adapter functions, such for each member function, for example:
class SessionCallback
{
For this:
ssh_channel channelOpen(ssh_session session)
{
channel = SSHChannel(session);
return channel->get();
}
Generate this:
static ssh_channel channelOpenAdapter(ssh_session session, void * userdata)
{
auto * self = static_cast<SessionCallback*>(userdata);
return self->channel_open;
}
}
We just static cast userdata to our class and then call member function.
This is needed to use c++ classes in libssh callbacks.
Maybe there is a better way? Or just write boilerplate code and avoid macros?
*/
#define GENERATE_ADAPTER_FUNCTION(class, func_name, return_type) \
template <typename... Args> \
static return_type func_name##Adapter(Args... args, void * userdata) \
{ \
auto * self = static_cast<class *>(userdata); \
return self->func_name(args...); \
}
}
namespace DB
{
namespace
{
// Wrapper around ssh_channel_callbacks. Each callback must not throw any exceptions, as c code is executed
class ChannelCallback
{
public:
using DescriptorSet = IClientDescriptorSet::DescriptorSet;
explicit ChannelCallback(::ssh::SSHChannel && channel_, std::unique_ptr<Session> && dbSession_)
: channel(std::move(channel_)), db_session(std::move(dbSession_)), log(&Poco::Logger::get("SSHChannelCallback"))
{
channel_cb.userdata = this;
channel_cb.channel_pty_request_function = ptyRequestAdapter<ssh_session, ssh_channel, const char *, int, int, int, int>;
channel_cb.channel_shell_request_function = shellRequestAdapter<ssh_session, ssh_channel>;
channel_cb.channel_data_function = dataFunctionAdapter<ssh_session, ssh_channel, void *, uint32_t, int>;
channel_cb.channel_pty_window_change_function = ptyResizeAdapter<ssh_session, ssh_channel, int, int, int, int>;
channel_cb.channel_env_request_function = envRequestAdapter<ssh_session, ssh_channel, const char *, const char*>;
channel_cb.channel_exec_request_function = execRequestAdapter<ssh_session, ssh_channel, const char *>;
channel_cb.channel_subsystem_request_function = subsystemRequestAdapter<ssh_session, ssh_channel, const char *>;
ssh_callbacks_init(&channel_cb) ssh_set_channel_callbacks(channel.getCChannelPtr(), &channel_cb);
}
bool hasClientFinished() { return client_runner.has_value() && client_runner->hasFinished(); }
DescriptorSet client_input_output;
::ssh::SSHChannel channel;
std::unique_ptr<Session> db_session;
NameToNameMap env;
std::optional<ClientEmbeddedRunner> client_runner;
Poco::Logger * log;
private:
int ptyRequest(ssh_session, ssh_channel, const char * term, int width, int height, int width_pixels, int height_pixels) noexcept
{
LOG_TRACE(log, "Received pty request");
if (!db_session || client_runner.has_value())
return SSH_ERROR;
try
{
auto client_descriptors = std::make_unique<PtyClientDescriptorSet>(String(term), width, height, width_pixels, height_pixels);
client_runner.emplace(std::move(client_descriptors), std::move(db_session));
}
catch (...)
{
tryLogCurrentException(log, "Exception from creating pty");
return SSH_ERROR;
}
return SSH_OK;
}
GENERATE_ADAPTER_FUNCTION(ChannelCallback, ptyRequest, int)
int ptyResize(ssh_session, ssh_channel, int width, int height, int width_pixels, int height_pixels) noexcept
{
LOG_TRACE(log, "Received pty resize");
if (!client_runner.has_value() || !client_runner->hasPty())
{
return SSH_ERROR;
}
try
{
client_runner->changeWindowSize(width, height, width_pixels, height_pixels);
return SSH_OK;
}
catch (...)
{
tryLogCurrentException(log, "Exception from changing window size");
return SSH_ERROR;
}
}
GENERATE_ADAPTER_FUNCTION(ChannelCallback, ptyResize, int)
int dataFunction(ssh_session, ssh_channel, void * data, uint32_t len, int /*is_stderr*/) const noexcept
{
if (len == 0 || client_input_output.in == -1)
{
return 0;
}
return static_cast<int>(write(client_input_output.in, data, len));
}
GENERATE_ADAPTER_FUNCTION(ChannelCallback, dataFunction, int)
int subsystemRequest(ssh_session, ssh_channel, const char * subsystem) noexcept
{
LOG_TRACE(log, "Received subsystem request");
if (strcmp(subsystem, "ch-client") != 0)
{
return SSH_ERROR;
}
LOG_TRACE(log, "Subsystem is supported");
if (!client_runner.has_value() || client_runner->hasStarted() || !client_runner->hasPty())
{
return SSH_ERROR;
}
try
{
client_runner->run(env);
client_input_output = client_runner->getDescriptorsForServer();
return SSH_OK;
}
catch (...)
{
tryLogCurrentException(log, "Exception from starting client");
return SSH_ERROR;
}
}
GENERATE_ADAPTER_FUNCTION(ChannelCallback, subsystemRequest, int)
int shellRequest(ssh_session, ssh_channel) noexcept
{
LOG_TRACE(log, "Received shell request");
if (!client_runner.has_value() || client_runner->hasStarted() || !client_runner->hasPty())
{
return SSH_ERROR;
}
try
{
client_runner->run(env);
client_input_output = client_runner->getDescriptorsForServer();
return SSH_OK;
}
catch (...)
{
tryLogCurrentException(log, "Exception from starting client");
return SSH_ERROR;
}
}
GENERATE_ADAPTER_FUNCTION(ChannelCallback, shellRequest, int)
int envRequest(ssh_session, ssh_channel, const char * env_name, const char * env_value)
{
LOG_TRACE(log, "Received env request");
env[env_name] = env_value;
return SSH_OK;
}
GENERATE_ADAPTER_FUNCTION(ChannelCallback, envRequest, int)
int execNopty(const String & command)
{
if (db_session)
{
try
{
auto client_descriptors = std::make_unique<PipeClientDescriptorSet>();
client_runner.emplace(std::move(client_descriptors), std::move(db_session));
client_runner->run(env, command);
client_input_output = client_runner->getDescriptorsForServer();
}
catch (...)
{
tryLogCurrentException(log, "Exception from starting client with no pty");
return SSH_ERROR;
}
}
return SSH_OK;
}
int execRequest(ssh_session, ssh_channel, const char * command)
{
LOG_TRACE(log, "Received exec request");
if (client_runner.has_value() && (client_runner->hasStarted() || !client_runner->hasPty()))
{
return SSH_ERROR;
}
if (client_runner.has_value())
{
try
{
client_runner->run(env, command);
client_input_output = client_runner->getDescriptorsForServer();
return SSH_OK;
}
catch (...)
{
tryLogCurrentException(log, "Exception from starting client with pre entered query");
return SSH_ERROR;
}
}
return execNopty(String(command));
}
GENERATE_ADAPTER_FUNCTION(ChannelCallback, execRequest, int)
ssh_channel_callbacks_struct channel_cb = {};
};
int process_stdout(socket_t fd, int revents, void * userdata)
{
char buf[1024];
int n = -1;
ssh_channel channel = static_cast<ssh_channel>(userdata);
if (channel != nullptr && (revents & POLLIN) != 0)
{
n = static_cast<int>(read(fd, buf, 1024));
if (n > 0)
{
ssh_channel_write(channel, buf, n);
}
}
return n;
}
int process_stderr(socket_t fd, int revents, void * userdata)
{
char buf[1024];
int n = -1;
ssh_channel channel = static_cast<ssh_channel>(userdata);
if (channel != nullptr && (revents & POLLIN) != 0)
{
n = static_cast<int>(read(fd, buf, 1024));
if (n > 0)
{
ssh_channel_write_stderr(channel, buf, n);
}
}
return n;
}
// Wrapper around ssh_server_callbacks. Each callback must not throw any exceptions, as c code is executed
class SessionCallback
{
public:
explicit SessionCallback(::ssh::SSHSession & session, IServer & server, const Poco::Net::SocketAddress & address_)
: server_context(server.context()), peer_address(address_), log(&Poco::Logger::get("SSHSessionCallback"))
{
server_cb.userdata = this;
server_cb.auth_password_function = authPasswordAdapter<ssh_session, const char*, const char*>;
server_cb.auth_pubkey_function = authPublickeyAdapter<ssh_session, const char *, ssh_key, char>;
ssh_set_auth_methods(session.getInternalPtr(), SSH_AUTH_METHOD_PASSWORD | SSH_AUTH_METHOD_PUBLICKEY);
server_cb.channel_open_request_session_function = channelOpenAdapter<ssh_session>;
ssh_callbacks_init(&server_cb) ssh_set_server_callbacks(session.getInternalPtr(), &server_cb);
}
size_t auth_attempts = 0;
bool authenticated = false;
std::unique_ptr<Session> db_session;
DB::ContextMutablePtr server_context;
Poco::Net::SocketAddress peer_address;
std::unique_ptr<ChannelCallback> channel_callback;
Poco::Logger * log;
private:
ssh_channel channelOpen(ssh_session session) noexcept
{
LOG_DEBUG(log, "Opening a channel");
if (!db_session)
{
return nullptr;
}
try
{
auto channel = ::ssh::SSHChannel(session);
channel_callback = std::make_unique<ChannelCallback>(std::move(channel), std::move(db_session));
return channel_callback->channel.getCChannelPtr();
}
catch (...)
{
tryLogCurrentException(log, "Error while opening channel:");
return nullptr;
}
}
GENERATE_ADAPTER_FUNCTION(SessionCallback, channelOpen, ssh_channel)
int authPassword(ssh_session, const char * user, const char * pass) noexcept
{
try
{
LOG_TRACE(log, "Authenticating with password");
auto db_session_created = std::make_unique<Session>(server_context, ClientInfo::Interface::LOCAL);
String user_name(user), password(pass);
db_session_created->authenticate(user_name, password, peer_address);
authenticated = true;
db_session = std::move(db_session_created);
return SSH_AUTH_SUCCESS;
}
catch (...)
{
++auth_attempts;
return SSH_AUTH_DENIED;
}
}
GENERATE_ADAPTER_FUNCTION(SessionCallback, authPassword, int)
int authPublickey(ssh_session, const char * user, ssh_key key, char signature_state) noexcept
{
try
{
LOG_TRACE(log, "Authenticating with public key");
auto db_session_created = std::make_unique<Session>(server_context, ClientInfo::Interface::LOCAL);
String user_name(user);
if (signature_state == SSH_PUBLICKEY_STATE_NONE)
{
// This is the case when user wants to check if he is able to use this type of authentication.
// Also here we may check if the key is associated with the user, but current session
// authentication mechanism doesn't support it.
const auto user_authentication_types = db_session_created->getAuthenticationTypes(user_name);
for (auto user_authentication_type : user_authentication_types)
if (user_authentication_type == AuthenticationType::SSH_KEY)
return SSH_AUTH_DENIED;
}
if (signature_state != SSH_PUBLICKEY_STATE_VALID)
{
++auth_attempts;
return SSH_AUTH_DENIED;
}
/// FIXME: Generate random string
String challenge="Hello...";
// The signature is checked, so just verify that user is associated with publickey.
// Function will throw if authentication fails.
db_session_created->authenticate(SshCredentials{user_name, SSHKey(key).signString(challenge), challenge}, peer_address);
authenticated = true;
db_session = std::move(db_session_created);
return SSH_AUTH_SUCCESS;
}
catch (...)
{
++auth_attempts;
return SSH_AUTH_DENIED;
}
}
GENERATE_ADAPTER_FUNCTION(SessionCallback, authPublickey, int)
ssh_server_callbacks_struct server_cb = {};
};
}
SSHPtyHandler::SSHPtyHandler(
IServer & server_,
::ssh::SSHSession session_,
const Poco::Net::StreamSocket & socket,
unsigned int max_auth_attempts_,
unsigned int auth_timeout_seconds_,
unsigned int finish_timeout_seconds_,
unsigned int event_poll_interval_milliseconds_)
: Poco::Net::TCPServerConnection(socket)
, server(server_)
, log(&Poco::Logger::get("SSHPtyHandler"))
, session(std::move(session_))
, max_auth_attempts(max_auth_attempts_)
, auth_timeout_seconds(auth_timeout_seconds_)
, finish_timeout_seconds(finish_timeout_seconds_)
, event_poll_interval_milliseconds(event_poll_interval_milliseconds_)
{
}
void SSHPtyHandler::run()
{
::ssh::SSHEvent event;
SessionCallback sdata(session, server, socket().peerAddress());
session.handleKeyExchange();
event.addSession(session);
int max_iterations = auth_timeout_seconds * 1000 / event_poll_interval_milliseconds;
int n = 0;
while (!sdata.authenticated || !sdata.channel_callback)
{
/* If the user has used up all attempts, or if he hasn't been able to
* authenticate in auth_timeout_seconds, disconnect. */
if (sdata.auth_attempts >= max_auth_attempts || n >= max_iterations)
{
return;
}
if (server.isCancelled())
{
return;
}
event.poll(event_poll_interval_milliseconds);
n++;
}
bool fds_set = false;
do
{
/* Poll the main event which takes care of the session, the channel and
* even our client's stdout/stderr (once it's started). */
event.poll(event_poll_interval_milliseconds);
/* If client's stdout/stderr has been registered with the event,
* or the client hasn't started yet, continue. */
if (fds_set || sdata.channel_callback->client_input_output.out == -1)
{
continue;
}
/* Executed only once, once the client starts. */
fds_set = true;
/* If stdout valid, add stdout to be monitored by the poll event. */
if (sdata.channel_callback->client_input_output.out != -1)
{
event.addFd(sdata.channel_callback->client_input_output.out, POLLIN, process_stdout, sdata.channel_callback->channel.getCChannelPtr());
}
if (sdata.channel_callback->client_input_output.err != -1)
{
event.addFd(sdata.channel_callback->client_input_output.err, POLLIN, process_stderr, sdata.channel_callback->channel.getCChannelPtr());
}
} while (sdata.channel_callback->channel.isOpen() && !sdata.channel_callback->hasClientFinished() && !server.isCancelled());
LOG_DEBUG(
log,
"Finishing connection with state: channel open: {}, embedded client finished: {}, server cancelled: {}",
sdata.channel_callback->channel.isOpen(), sdata.channel_callback->hasClientFinished(), server.isCancelled()
);
event.removeFd(sdata.channel_callback->client_input_output.out);
event.removeFd(sdata.channel_callback->client_input_output.err);
sdata.channel_callback->channel.sendEof();
sdata.channel_callback->channel.close();
/* Wait up to finish_timeout_seconds seconds for the client to terminate the session. */
max_iterations = finish_timeout_seconds * 1000 / event_poll_interval_milliseconds;
for (n = 0; n < max_iterations && !session.hasFinished(); n++)
{
event.poll(event_poll_interval_milliseconds);
}
LOG_DEBUG(log, "Connection closed");
}
}
#endif

View File

@ -0,0 +1,43 @@
#pragma once
#include "config.h"
#if USE_SSH
#include <Poco/Net/StreamSocket.h>
#include <Poco/Net/TCPServerConnection.h>
#include <Server/IServer.h>
#include <Server/SSH/SSHSession.h>
namespace DB
{
class SSHPtyHandler : public Poco::Net::TCPServerConnection
{
public:
explicit SSHPtyHandler
(
IServer & server_,
::ssh::SSHSession session_,
const Poco::Net::StreamSocket & socket,
unsigned int max_auth_attempts_,
unsigned int auth_timeout_seconds_,
unsigned int finish_timeout_seconds_,
unsigned int event_poll_interval_milliseconds_
);
void run() override;
private:
IServer & server;
Poco::Logger * log;
::ssh::SSHSession session;
unsigned int max_auth_attempts;
unsigned int auth_timeout_seconds;
unsigned int finish_timeout_seconds;
unsigned int event_poll_interval_milliseconds;
};
}
#endif

View File

@ -0,0 +1,120 @@
#pragma once
#include "config.h"
#if USE_SSH
#include <optional>
#include <Server/SSH/SSHPtyHandler.h>
#include <Server/TCPServer.h>
#include <Server/TCPServerConnectionFactory.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
#include <Server/IServer.h>
#include <Common/LibSSHLogger.h>
#include <Server/SSH/SSHBind.h>
#include <Server/SSH/SSHSession.h>
namespace Poco
{
class Logger;
}
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
class SSHPtyHandlerFactory : public TCPServerConnectionFactory
{
private:
IServer & server;
Poco::Logger * log;
::ssh::SSHBind bind;
unsigned int max_auth_attempts;
unsigned int auth_timeout_seconds;
unsigned int finish_timeout_seconds;
unsigned int event_poll_interval_milliseconds;
std::optional<int> read_write_timeout_seconds; // optional here, as libssh has its own defaults
std::optional<int> read_write_timeout_micro_seconds;
public:
explicit SSHPtyHandlerFactory(
IServer & server_, const Poco::Util::AbstractConfiguration & config)
: server(server_), log(&Poco::Logger::get("SSHHandlerFactory"))
{
LOG_INFO(log, "Initializing sshbind");
bind.disableDefaultConfig();
String prefix = "ssh.";
auto rsa_key = config.getString(prefix + "host_rsa_key", "");
auto ecdsa_key = config.getString(prefix + "host_ecdsa_key", "");
auto ed25519_key = config.getString(prefix + "host_ed25519_key", "");
max_auth_attempts = config.getUInt("max_auth_attempts", 4);
auth_timeout_seconds = config.getUInt("auth_timeout_seconds", 10);
finish_timeout_seconds = config.getUInt("finish_timeout_seconds", 5);
event_poll_interval_milliseconds = config.getUInt("event_poll_interval_milliseconds", 100);
if (config.has("read_write_timeout_seconds"))
{
read_write_timeout_seconds = config.getInt("read_write_timeout_seconds");
if (read_write_timeout_seconds.value() < 0)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Negative timeout specified");
}
if (config.has("read_write_timeout_micro_seconds"))
{
read_write_timeout_micro_seconds = config.getInt("read_write_timeout_micro_seconds");
if (read_write_timeout_micro_seconds.value() < 0)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Negative timeout specified");
}
if (event_poll_interval_milliseconds == 0)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Poll interval must be positive");
if (auth_timeout_seconds * 1000 < event_poll_interval_milliseconds)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Poll interval exceeds auth timeout");
if (finish_timeout_seconds * 1000 < event_poll_interval_milliseconds)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Poll interval exceeds finish timeout");
if (rsa_key.empty() && ecdsa_key.empty() && ed25519_key.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Host key for ssh endpoint is not initialized");
if (!rsa_key.empty())
bind.setHostKey(rsa_key);
if (!ecdsa_key.empty())
bind.setHostKey(ecdsa_key);
if (!ed25519_key.empty())
bind.setHostKey(ed25519_key);
}
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket, TCPServer &) override
{
LOG_TRACE(log, "TCP Request. Address: {}", socket.peerAddress().toString());
::ssh::libsshLogger::initialize();
::ssh::SSHSession session;
session.disableSocketOwning();
session.disableDefaultConfig();
if (read_write_timeout_seconds.has_value() || read_write_timeout_micro_seconds.has_value())
{
session.setTimeout(read_write_timeout_seconds.value_or(0), read_write_timeout_micro_seconds.value_or(0));
}
bind.acceptFd(session, socket.sockfd());
return new SSHPtyHandler(
server,
std::move(session),
socket,
max_auth_attempts,
auth_timeout_seconds,
finish_timeout_seconds,
event_poll_interval_milliseconds);
}
};
}
#endif

View File

@ -0,0 +1,124 @@
#include <Server/SSH/SSHSession.h>
#if USE_SSH
#include <fmt/format.h>
#include <Common/Exception.h>
#include <Common/clibssh.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SSH_EXCEPTION;
}
}
namespace ssh
{
SSHSession::SSHSession() : session(ssh_new())
{
if (!session)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed to create ssh_session");
}
SSHSession::~SSHSession()
{
if (session)
ssh_free(session);
}
SSHSession::SSHSession(SSHSession && rhs) noexcept
{
*this = std::move(rhs);
}
SSHSession & SSHSession::operator=(SSHSession && rhs) noexcept
{
this->session = rhs.session;
rhs.session = nullptr;
return *this;
}
SSHSession::SessionPtr SSHSession::getInternalPtr() const
{
return session;
}
void SSHSession::connect()
{
int rc = ssh_connect(session);
if (rc != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed connecting in ssh session due due to {}", getError());
}
void SSHSession::disableDefaultConfig()
{
bool enable = false;
int rc = ssh_options_set(session, SSH_OPTIONS_PROCESS_CONFIG, &enable);
if (rc != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed disabling default config for ssh session due due to {}", getError());
}
void SSHSession::disableSocketOwning()
{
bool owns_socket = false;
int rc = ssh_options_set(session, SSH_OPTIONS_OWNS_SOCKET, &owns_socket);
if (rc != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed disabling socket owning for ssh session due to {}", getError());
}
void SSHSession::setPeerHost(const String & host)
{
int rc = ssh_options_set(session, SSH_OPTIONS_HOST, host.c_str());
if (rc != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting peer host option for ssh session due due to {}", getError());
}
void SSHSession::setFd(int fd)
{
int rc = ssh_options_set(session, SSH_OPTIONS_FD, &fd);
if (rc != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting fd option for ssh session due due to {}", getError());
}
void SSHSession::setTimeout(int timeout, int timeout_usec)
{
int rc = ssh_options_set(session, SSH_OPTIONS_TIMEOUT, &timeout);
if (rc != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting for ssh session due timeout option due to {}", getError());
rc |= ssh_options_set(session, SSH_OPTIONS_TIMEOUT_USEC, &timeout_usec);
if (rc != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed setting for ssh session due timeout_usec option due to {}", getError());
}
void SSHSession::handleKeyExchange()
{
int rc = ssh_handle_key_exchange(session);
if (rc != SSH_OK)
throw DB::Exception(DB::ErrorCodes::SSH_EXCEPTION, "Failed key exchange for ssh session due to {}", getError());
}
void SSHSession::disconnect()
{
ssh_disconnect(session);
}
String SSHSession::getError()
{
return String(ssh_get_error(session));
}
bool SSHSession::hasFinished()
{
return ssh_get_status(session) & (SSH_CLOSED | SSH_CLOSED_ERROR);
}
}
#endif

View File

@ -0,0 +1,60 @@
#pragma once
#include "config.h"
#if USE_SSH
#include <memory>
#include <base/types.h>
struct ssh_session_struct;
namespace ssh
{
// Wrapper around libssh's ssh_session
class SSHSession
{
public:
SSHSession();
~SSHSession();
SSHSession(SSHSession &&) noexcept;
SSHSession & operator=(SSHSession &&) noexcept;
SSHSession(const SSHSession &) = delete;
SSHSession & operator=(const SSHSession &) = delete;
using SessionPtr = ssh_session_struct *;
/// Get raw pointer from libssh to be able to pass it to other objects
SessionPtr getInternalPtr() const;
/// Disable reading default libssh configuration
void disableDefaultConfig();
/// Disable session from closing socket. Can be used when a socket is passed.
void disableSocketOwning();
/// Connect / disconnect
void connect();
void disconnect();
/// Configure session
void setPeerHost(const String & host);
// Pass ready socket to session
void setFd(int fd);
void setTimeout(int timeout, int timeout_usec);
void handleKeyExchange();
/// Error handling
String getError();
// Check that session was closed
bool hasFinished();
private:
SessionPtr session = nullptr;
};
}
#endif

View File

@ -47,6 +47,7 @@ bool ServerType::shouldStart(Type server_type, const std::string & server_custom
switch (current_type)
{
case Type::TCP:
case Type::TCP_SSH:
case Type::TCP_WITH_PROXY:
case Type::TCP_SECURE:
case Type::HTTP:
@ -110,6 +111,9 @@ bool ServerType::shouldStop(const std::string & port_name) const
else if (port_name == "tcp_port")
port_type = Type::TCP;
else if (port_name == "tcp_ssh_port")
port_type = Type::TCP_SSH;
else if (port_name == "tcp_with_proxy_port")
port_type = Type::TCP_WITH_PROXY;

View File

@ -13,6 +13,7 @@ public:
{
TCP_WITH_PROXY,
TCP_SECURE,
TCP_SSH,
TCP,
HTTP,
HTTPS,

View File

@ -37,6 +37,8 @@ public:
UInt16 portNumber() const { return port_number; }
const Poco::Net::ServerSocket& getSocket() { return socket; }
private:
TCPServerConnectionFactory::Ptr factory;
Poco::Net::ServerSocket socket;

View File

View File

@ -0,0 +1,6 @@
<clickhouse>
<tcp_ssh_port>9022</tcp_ssh_port>
<ssh>
<host_rsa_key>/etc/clickhouse-server/config.d/ssh_host_rsa_key</host_rsa_key>
</ssh>
</clickhouse>

View File

@ -0,0 +1,12 @@
<clickhouse>
<users>
<lucy>
<ssh_keys>
<ssh_key>
<type>ssh-ed25519</type>
<base64_key>AAAAC3NzaC1lZDI1NTE5AAAAIA5p06mOZGpz7ePU57OmQ08v3U+CpWa2u1f9/V/yoZ1n</base64_key>
</ssh_key>
</ssh_keys>
</lucy>
</users>
</clickhouse>

View File

@ -0,0 +1,7 @@
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
QyNTUxOQAAACAOadOpjmRqc+3j1OezpkNPL91PgqVmtrtX/f1f8qGdZwAAAKDbJgdb2yYH
WwAAAAtzc2gtZWQyNTUxOQAAACAOadOpjmRqc+3j1OezpkNPL91PgqVmtrtX/f1f8qGdZw
AAAECcs+RQOSe++AHZNDgAPwkfwd6t6e6HLN7c0ZDFXAGJ0g5p06mOZGpz7ePU57OmQ08v
3U+CpWa2u1f9/V/yoZ1nAAAAFnVidW50dUBpcC0xMC0xMC0xMC0xMDQBAgMEBQYH
-----END OPENSSH PRIVATE KEY-----

View File

@ -0,0 +1 @@
ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIA5p06mOZGpz7ePU57OmQ08v3U+CpWa2u1f9/V/yoZ1n test

View File

@ -0,0 +1,49 @@
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAACFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAgEArzrm/JkXzklKKcz1IKY6ivokvWmvpJRx3TmJP/hA/hGQ+YZGen7A
2RD+enL+pwnqjknGBMFfyaGJWFfOHVlYxZ4BTkpUXIayVm1hiTnfoNMvXN+dKbf8K/u95U
VLPhW08e1+ymB+ZBal7lSZL2gLcyDT4Riy8Yu3kDuxo88GRp3hXU+Ygn7FMTxYM+yCj8rx
9nIqwYA721uXTmADPYZLh9j+SJI4nJN0eiQ/9IrMRJDTeEPi3M0pXLXIdC13j4O9+xY8Rw
PVOoQl+5ntEWgfWzu1LLf6ODqZKr1V/yMuHKXCKIqbwmb/pfwaec0yaDDFQT48DsI14+pB
O3fL7cPzqBCEeeDugXyiyb83zNy9qPLAdd0w16tkYrNYp9C/LRZP+M6GvFPX9PJ/130Yf7
q3GeRJAv3bL3MZ9pbqFJpzzrVgAeqRXF/C73zGENaffKLGFJ5CyxZKI+VfhtIbrp23PDDz
JFTMBQuy9omdevfqRXuDPEhuwFvpTkQumzX3OHNIXKiyoAwDd4wXRNWWqj7eEPNUATeY/k
lbHfRGjXFNNhDCptrekT6mCYGYo+BzEoKGui5onoS4yx1hMrykka1SzzAKlljqSY/ezlTg
5Z+xcxD9XWYZXFannB7LJZvGAJ1QWeoH48FGBNKOTjKlOiNRL2JuvlhPVWkEk/3AdcPdGn
0AAAdQDlyKDQ5cig0AAAAHc3NoLXJzYQAAAgEArzrm/JkXzklKKcz1IKY6ivokvWmvpJRx
3TmJP/hA/hGQ+YZGen7A2RD+enL+pwnqjknGBMFfyaGJWFfOHVlYxZ4BTkpUXIayVm1hiT
nfoNMvXN+dKbf8K/u95UVLPhW08e1+ymB+ZBal7lSZL2gLcyDT4Riy8Yu3kDuxo88GRp3h
XU+Ygn7FMTxYM+yCj8rx9nIqwYA721uXTmADPYZLh9j+SJI4nJN0eiQ/9IrMRJDTeEPi3M
0pXLXIdC13j4O9+xY8RwPVOoQl+5ntEWgfWzu1LLf6ODqZKr1V/yMuHKXCKIqbwmb/pfwa
ec0yaDDFQT48DsI14+pBO3fL7cPzqBCEeeDugXyiyb83zNy9qPLAdd0w16tkYrNYp9C/LR
ZP+M6GvFPX9PJ/130Yf7q3GeRJAv3bL3MZ9pbqFJpzzrVgAeqRXF/C73zGENaffKLGFJ5C
yxZKI+VfhtIbrp23PDDzJFTMBQuy9omdevfqRXuDPEhuwFvpTkQumzX3OHNIXKiyoAwDd4
wXRNWWqj7eEPNUATeY/klbHfRGjXFNNhDCptrekT6mCYGYo+BzEoKGui5onoS4yx1hMryk
ka1SzzAKlljqSY/ezlTg5Z+xcxD9XWYZXFannB7LJZvGAJ1QWeoH48FGBNKOTjKlOiNRL2
JuvlhPVWkEk/3AdcPdGn0AAAADAQABAAACAAK4gt9zzu25idqH5GU0qQSY66knYl0X4VR+
/1V24kOaTLP6if4pUgDls8gxkh76/TtdbFvY4GzZ1XsnizIE1Zai4gd0qsZA7wDwuJ2o77
vCZlioplUAZydX5eD5K9CYbcxPLGVE4lIZjtAvp+YN9of+LzdGBRG1A4Y+SbNZg9RgFkHO
g8gyLrP4bg9HS58B50pH9KtdZy9WM/gYEFMJ1zvUlIeBFPgg9cSGU04vvlYD6Na8S7KgX3
ZtJKlJ36kW/8/9ENIykT3dcmlPB7LKjBEN5Bx2sk83TkyEa4SviNZBcPzbAyLRuSU7IJOq
uFviHwL75auHli2fL96DHvth8eUKD1UeissdnvVRcV/FYE2GXabBLDr4YIn9h4bLJ9lI/d
eDDPFagASWJisk1PfGwpEdpw6tZrYEerOkQTIwBRShB4xpYcCJ4eACafL64U9Az2jFdxK3
gaLUr1Bkhus48Pp4YIGrhClFz3yIPTkS9XmIBlnrxa64cNH0A+i9yaAe6esvTPRyDRT3MB
oPwZQczzuwQaODEVFz5a+L4U2pg4vBK2StLm0166R1lz0czXeFksx97F+8CA/gvPooVtI4
6SZpqJrsQ3ajjmQ/OYDWn8owzb2o463xy2MNsy94cMb7CkO+OU8tV8ytW0hzigFvRRtB3P
YoOygUHIYE7f35hUMjAAABACI1GE2wzgqoqKmtTQtSY4LfX86Y4j2W+7tscirzPG0XcGrh
fl2XHht21n7C2LSU3AoXgT9uyzBrn1oCrV+bUusi7wnRQ2DStLlIIegz+lTbZfa+CV3km4
mJh/BCeunRq/XrJXN2XMNPcVcR4ProqVdTBYXW2gOJYC+wkrawuKZPCDXS3BOKXlPJI3/m
sggxzHIZL/4ndrez/e4QR8GTpRzT6gM4PbgtEePL7XbPjWsD4BrOWM4wZmf7T2z4ohc/bj
4hnJiAokhKoir/4wKZGJ7OA1b7rW/mFFpIMG+1oVqQZuJwa9HkbStLjQCJu5xywlUVJtqK
BXy4kV4+SbsId1kAAAEBAN5JsuqpJnQN22S3iq4KIHIL1/vJCYm/xweVYAhgyJ1DLfPH12
EEwUT72apAG4URlarL8hcgJYKZ73JZuzVH6nW6xLS8U0NeIblIh74qwV008i8av9qprCfV
LexyRx0Zf5g85GA72kZTKTCD0LdzjjMHGoWcx34aAaPSKvQdsd8anttRL34LefG9VQWQwL
IQsZatwfFVM9O7gtN++0Dqd8kAmlHEQcMvpmbhPfpsr+MIG6sO4MRqGJ+8ior2YCdK7cUX
59PlVv6A1B7Fh3WVKfYT+s+yejGMua1aX3auLX3kEHFG0ckjK1i2r5wEbT6Tvjgz7sGFOx
+HfwvkCWz2Bk8AAAEBAMnOMJr4RH8xHI/NIvYB8+0m9inkVCKiDRVIKJQPeOLgc/5jWTTN
uErVIsN+gQXxIhdG3/0tWu/sVIASfGAkbYuvSzRT/bWfTe5Oc6gNoBEC5vGzHwNaJVZnPe
k40ON1mXtVlJ167Ku3xx5IekyIKnWxAnk/00tX7Ig5vDspFIsQuXmUP3Zs29ZlfkPi7ebO
884o15meC+104mbFLgNzKdm+XrQuP+is3sQ7CzB45wSRNwVlBJSWq+qIt0/LyPsX9fShr3
rRKl3lyUcrgRexwYXI3bnP/LwnvUROToi6/30NE7rcELXZ29E+R4H45GOEGmt8wIU3bJyX
T5tk8CPnK3MAAAAWdWJ1bnR1QGlwLTEwLTEwLTEwLTEwNAECAwQF
-----END OPENSSH PRIVATE KEY-----

View File

@ -0,0 +1 @@
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQCvOub8mRfOSUopzPUgpjqK+iS9aa+klHHdOYk/+ED+EZD5hkZ6fsDZEP56cv6nCeqOScYEwV/JoYlYV84dWVjFngFOSlRchrJWbWGJOd+g0y9c350pt/wr+73lRUs+FbTx7X7KYH5kFqXuVJkvaAtzINPhGLLxi7eQO7GjzwZGneFdT5iCfsUxPFgz7IKPyvH2cirBgDvbW5dOYAM9hkuH2P5Ikjick3R6JD/0isxEkNN4Q+LczSlctch0LXePg737FjxHA9U6hCX7me0RaB9bO7Ust/o4OpkqvVX/Iy4cpcIoipvCZv+l/Bp5zTJoMMVBPjwOwjXj6kE7d8vtw/OoEIR54O6BfKLJvzfM3L2o8sB13TDXq2Ris1in0L8tFk/4zoa8U9f08n/XfRh/urcZ5EkC/dsvcxn2luoUmnPOtWAB6pFcX8LvfMYQ1p98osYUnkLLFkoj5V+G0huunbc8MPMkVMwFC7L2iZ169+pFe4M8SG7AW+lORC6bNfc4c0hcqLKgDAN3jBdE1ZaqPt4Q81QBN5j+SVsd9EaNcU02EMKm2t6RPqYJgZij4HMSgoa6LmiehLjLHWEyvKSRrVLPMAqWWOpJj97OVODln7FzEP1dZhlcVqecHsslm8YAnVBZ6gfjwUYE0o5OMqU6I1EvYm6+WE9VaQST/cB1w90afQ== test

View File

@ -0,0 +1,45 @@
import subprocess
import pytest
import os
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__)
instance = cluster.add_instance(
"node",
main_configs=["configs/server.xml", "keys/ssh_host_rsa_key"],
user_configs=["configs/users.xml"],
)
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
@pytest.fixture(scope="module", autouse=True)
def started_cluster():
try:
cluster.start()
yield cluster
finally:
cluster.shutdown()
def test_simple_query_with_openssh_client():
ssh_command = (
"ssh -o StrictHostKeyChecking"
+ f"=no lucy@{instance.ip_address} -p 9022"
+ f' -i {SCRIPT_DIR}/keys/lucy_ed25519 "select 1"'
)
completed_process = subprocess.run(
ssh_command,
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
expected = instance.query("select 1")
output = completed_process.stdout
assert output.replace("\n\x00", "\n") == expected

View File

@ -328,7 +328,10 @@ std_cerr_cout_excludes=(
src/Processors/IProcessor.cpp
src/Client/ClientApplicationBase.cpp
src/Client/ClientBase.cpp
src/Common/ProgressIndication.h
src/Client/LineReader.h
src/Client/LineReader.cpp
src/Client/ReplxxLineReader.h
src/Client/QueryFuzzer.cpp
src/Client/Suggest.cpp
src/Client/ClientBase.h
@ -348,7 +351,7 @@ sources_with_std_cerr_cout=( $(
) )
# Exclude comments
for src in "${sources_with_std_cerr_cout[@]}"; do
# suppress stderr, since it may contain warning for #pargma once in headers
# suppress stderr, since it may contain warning for #pragma once in headers
if gcc -fpreprocessed -dD -E "$src" 2>/dev/null | grep -F -q -e std::cerr -e std::cout; then
echo "$src: uses std::cerr/std::cout"
fi