mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 00:22:29 +00:00
Merge branch 'master' of github.com:yandex/ClickHouse
This commit is contained in:
commit
0eade98688
@ -52,12 +52,12 @@ IncludeCategories:
|
||||
ReflowComments: false
|
||||
AlignEscapedNewlinesLeft: false
|
||||
AlignEscapedNewlines: DontAlign
|
||||
AlignTrailingComments: true
|
||||
|
||||
# Not changed:
|
||||
AccessModifierOffset: -4
|
||||
AlignConsecutiveAssignments: false
|
||||
AlignOperands: false
|
||||
AlignTrailingComments: false
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
AllowShortBlocksOnASingleLine: false
|
||||
AllowShortCaseLabelsOnASingleLine: false
|
||||
|
@ -11,8 +11,3 @@ ClickHouse is an open-source column-oriented database management system that all
|
||||
* [Blog](https://clickhouse.yandex/blog/en/) contains various ClickHouse-related articles, as well as announces and reports about events.
|
||||
* [Contacts](https://clickhouse.yandex/#contacts) can help to get your questions answered if there are any.
|
||||
* You can also [fill this form](https://forms.yandex.com/surveys/meet-yandex-clickhouse-team/) to meet Yandex ClickHouse team in person.
|
||||
|
||||
## Upcoming Events
|
||||
|
||||
* [ClickHouse Meetup in San Francisco](https://www.eventbrite.com/e/clickhouse-february-meetup-registration-88496227599) on February 5.
|
||||
* [ClickHouse Meetup in New York](https://www.meetup.com/Uber-Engineering-Events-New-York/events/268328663/) on February 11.
|
||||
|
@ -1,4 +1,4 @@
|
||||
if (NOT APPLE AND NOT ARCH_32)
|
||||
if (NOT ARCH_32)
|
||||
option (USE_INTERNAL_LIBGSASL_LIBRARY "Set to FALSE to use system libgsasl library instead of bundled" ${NOT_UNBUNDLED})
|
||||
endif ()
|
||||
|
||||
@ -16,7 +16,7 @@ if (NOT USE_INTERNAL_LIBGSASL_LIBRARY)
|
||||
endif ()
|
||||
|
||||
if (LIBGSASL_LIBRARY AND LIBGSASL_INCLUDE_DIR)
|
||||
elseif (NOT MISSING_INTERNAL_LIBGSASL_LIBRARY AND NOT APPLE AND NOT ARCH_32)
|
||||
elseif (NOT MISSING_INTERNAL_LIBGSASL_LIBRARY AND NOT ARCH_32)
|
||||
set (LIBGSASL_INCLUDE_DIR ${ClickHouse_SOURCE_DIR}/contrib/libgsasl/src ${ClickHouse_SOURCE_DIR}/contrib/libgsasl/linux_x86_64/include)
|
||||
set (USE_INTERNAL_LIBGSASL_LIBRARY 1)
|
||||
set (LIBGSASL_LIBRARY libgsasl)
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Freebsd: contrib/cppkafka/include/cppkafka/detail/endianness.h:53:23: error: 'betoh16' was not declared in this scope
|
||||
if (NOT ARCH_ARM AND NOT ARCH_32 AND NOT APPLE AND NOT OS_FREEBSD AND OPENSSL_FOUND)
|
||||
if (NOT ARCH_ARM AND NOT ARCH_32 AND NOT OS_FREEBSD AND OPENSSL_FOUND)
|
||||
option (ENABLE_RDKAFKA "Enable kafka" ${ENABLE_LIBRARIES})
|
||||
endif ()
|
||||
|
||||
@ -10,7 +10,7 @@ endif ()
|
||||
|
||||
if (ENABLE_RDKAFKA)
|
||||
|
||||
if (OS_LINUX AND NOT ARCH_ARM AND USE_LIBGSASL)
|
||||
if (NOT ARCH_ARM AND USE_LIBGSASL)
|
||||
option (USE_INTERNAL_RDKAFKA_LIBRARY "Set to FALSE to use system librdkafka instead of the bundled" ${NOT_UNBUNDLED})
|
||||
endif ()
|
||||
|
||||
|
@ -146,3 +146,5 @@ target_compile_definitions(curl PRIVATE HAVE_CONFIG_H BUILDING_LIBCURL CURL_HIDD
|
||||
target_include_directories(curl PUBLIC ${CURL_DIR}/include ${CURL_DIR}/lib .)
|
||||
|
||||
target_compile_definitions(curl PRIVATE OS="${CMAKE_SYSTEM_NAME}")
|
||||
|
||||
target_link_libraries(curl PRIVATE ssl)
|
||||
|
@ -1,3 +1,4 @@
|
||||
#define CURL_CA_BUNDLE "/etc/ssl/certs/ca-certificates.crt"
|
||||
#define CURL_DISABLE_FTP
|
||||
#define CURL_DISABLE_TFTP
|
||||
#define CURL_DISABLE_LDAP
|
||||
@ -9,9 +10,14 @@
|
||||
#define SIZEOF_CURL_OFF_T 8
|
||||
#define SIZEOF_SIZE_T 8
|
||||
|
||||
#define HAVE_ALARM
|
||||
#define HAVE_FCNTL_O_NONBLOCK
|
||||
#define HAVE_GETADDRINFO
|
||||
#define HAVE_LONGLONG
|
||||
#define HAVE_POLL_FINE
|
||||
#define HAVE_SIGACTION
|
||||
#define HAVE_SIGNAL
|
||||
#define HAVE_SIGSETJMP
|
||||
#define HAVE_SOCKET
|
||||
#define HAVE_STRUCT_TIMEVAL
|
||||
|
||||
@ -34,5 +40,11 @@
|
||||
#define HAVE_ERRNO_H
|
||||
#define HAVE_FCNTL_H
|
||||
#define HAVE_NETDB_H
|
||||
#define HAVE_NETINET_IN_H
|
||||
#define HAVE_SETJMP_H
|
||||
#define HAVE_SYS_STAT_H
|
||||
#define HAVE_UNISTD_H
|
||||
|
||||
#define ENABLE_IPV6
|
||||
#define USE_OPENSSL
|
||||
#define USE_THREADS_POSIX
|
||||
|
2
contrib/libgsasl
vendored
2
contrib/libgsasl
vendored
@ -1 +1 @@
|
||||
Subproject commit 3b8948a4042e34fb00b4fb987535dc9e02e39040
|
||||
Subproject commit 42ef20687042637252e64df1934b6d47771486d1
|
2
contrib/librdkafka
vendored
2
contrib/librdkafka
vendored
@ -1 +1 @@
|
||||
Subproject commit 6160ec275a5bb0a4088ede3c5f2afde638bbef65
|
||||
Subproject commit 4ffe54b4f59ee5ae3767f9f25dc14651a3384d62
|
@ -23,6 +23,8 @@ set(SRCS
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_lz4.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_metadata.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_metadata_cache.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_mock.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_mock_handlers.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_msg.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_msgset_reader.c
|
||||
${RDKAFKA_SOURCE_DIR}/rdkafka_msgset_writer.c
|
||||
|
@ -75,8 +75,18 @@
|
||||
#define HAVE_STRNDUP 1
|
||||
// strerror_r
|
||||
#define HAVE_STRERROR_R 1
|
||||
|
||||
#ifdef __APPLE__
|
||||
// pthread_setname_np
|
||||
#define HAVE_PTHREAD_SETNAME_DARWIN 1
|
||||
#if (__ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__ <= 101400)
|
||||
#define _TTHREAD_EMULATE_TIMESPEC_GET_
|
||||
#endif
|
||||
|
||||
#else
|
||||
// pthread_setname_gnu
|
||||
#define HAVE_PTHREAD_SETNAME_GNU 1
|
||||
#endif
|
||||
// python
|
||||
//#define HAVE_PYTHON 1
|
||||
// disable C11 threads for compatibility with old libc
|
||||
|
2
contrib/poco
vendored
2
contrib/poco
vendored
@ -1 +1 @@
|
||||
Subproject commit d478f62bd93c9cd14eb343756ef73a4ae622ddf5
|
||||
Subproject commit d805cf5ca4cf8bdc642261cfcbe7a0a241cb7298
|
@ -575,6 +575,7 @@ void HTTPHandler::processQuery(
|
||||
try
|
||||
{
|
||||
char b;
|
||||
//FIXME looks like MSG_DONTWAIT is useless because of POCO_BROKEN_TIMEOUTS
|
||||
int status = socket.receiveBytes(&b, 1, MSG_DONTWAIT | MSG_PEEK);
|
||||
if (status == 0)
|
||||
context.killCurrentQuery();
|
||||
|
@ -218,7 +218,7 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
|
||||
try
|
||||
{
|
||||
// For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
|
||||
auto user = connection_context.getUser(user_name);
|
||||
auto user = connection_context.getAccessControlManager().getUser(user_name);
|
||||
const DB::Authentication::Type user_auth_type = user->authentication.getType();
|
||||
if (user_auth_type != DB::Authentication::DOUBLE_SHA1_PASSWORD && user_auth_type != DB::Authentication::PLAINTEXT_PASSWORD && user_auth_type != DB::Authentication::NO_PASSWORD)
|
||||
{
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <Common/config.h>
|
||||
#include <Poco/Net/TCPServerConnection.h>
|
||||
#include <Common/getFQDNOrHostName.h>
|
||||
#include <Common/CurrentMetrics.h>
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include "IServer.h"
|
||||
|
||||
@ -9,6 +10,11 @@
|
||||
#include <Poco/Net/SecureStreamSocket.h>
|
||||
#endif
|
||||
|
||||
namespace CurrentMetrics
|
||||
{
|
||||
extern const Metric MySQLConnection;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/// Handler for MySQL wire protocol connections. Allows to connect to ClickHouse using MySQL client.
|
||||
@ -20,6 +26,8 @@ public:
|
||||
void run() final;
|
||||
|
||||
private:
|
||||
CurrentMetrics::Increment metric_increment{CurrentMetrics::MySQLConnection};
|
||||
|
||||
/// Enables SSL, if client requested.
|
||||
void finishHandshake(MySQLProtocol::HandshakeResponse &);
|
||||
|
||||
|
@ -900,6 +900,10 @@ void TCPHandler::receiveQuery()
|
||||
client_info.initial_query_id = client_info.current_query_id;
|
||||
client_info.initial_address = client_info.current_address;
|
||||
}
|
||||
else
|
||||
{
|
||||
query_context->switchRowPolicy();
|
||||
}
|
||||
}
|
||||
|
||||
/// Per query settings.
|
||||
|
@ -185,7 +185,7 @@
|
||||
<mlock_executable>false</mlock_executable>
|
||||
|
||||
<!-- Configuration of clusters that could be used in Distributed tables.
|
||||
https://clickhouse.yandex/docs/en/table_engines/distributed/
|
||||
https://clickhouse.tech/docs/en/operations/table_engines/distributed/
|
||||
-->
|
||||
<remote_servers incl="clickhouse_remote_servers" >
|
||||
<!-- Test only shard config for testing distributed storage -->
|
||||
|
@ -35,45 +35,49 @@ AccessControlManager::~AccessControlManager()
|
||||
}
|
||||
|
||||
|
||||
UserPtr AccessControlManager::getUser(const String & user_name) const
|
||||
UserPtr AccessControlManager::getUser(
|
||||
const String & user_name, std::function<void(const UserPtr &)> on_change, ext::scope_guard * subscription) const
|
||||
{
|
||||
return getUser(user_name, {}, nullptr);
|
||||
return getUser(getID<User>(user_name), std::move(on_change), subscription);
|
||||
}
|
||||
|
||||
|
||||
UserPtr AccessControlManager::getUser(
|
||||
const String & user_name, const std::function<void(const UserPtr &)> & on_change, ext::scope_guard * subscription) const
|
||||
const UUID & user_id, std::function<void(const UserPtr &)> on_change, ext::scope_guard * subscription) const
|
||||
{
|
||||
UUID id = getID<User>(user_name);
|
||||
if (on_change && subscription)
|
||||
{
|
||||
*subscription = subscribeForChanges(id, [on_change](const UUID &, const AccessEntityPtr & user)
|
||||
*subscription = subscribeForChanges(user_id, [on_change](const UUID &, const AccessEntityPtr & user)
|
||||
{
|
||||
if (user)
|
||||
on_change(typeid_cast<UserPtr>(user));
|
||||
});
|
||||
}
|
||||
return read<User>(id);
|
||||
return read<User>(user_id);
|
||||
}
|
||||
|
||||
|
||||
UserPtr AccessControlManager::authorizeAndGetUser(
|
||||
const String & user_name,
|
||||
const String & password,
|
||||
const Poco::Net::IPAddress & address) const
|
||||
{
|
||||
return authorizeAndGetUser(user_name, password, address, {}, nullptr);
|
||||
}
|
||||
|
||||
UserPtr AccessControlManager::authorizeAndGetUser(
|
||||
const String & user_name,
|
||||
const String & password,
|
||||
const Poco::Net::IPAddress & address,
|
||||
const std::function<void(const UserPtr &)> & on_change,
|
||||
std::function<void(const UserPtr &)> on_change,
|
||||
ext::scope_guard * subscription) const
|
||||
{
|
||||
auto user = getUser(user_name, on_change, subscription);
|
||||
user->allowed_client_hosts.checkContains(address, user_name);
|
||||
user->authentication.checkPassword(password, user_name);
|
||||
return authorizeAndGetUser(getID<User>(user_name), password, address, std::move(on_change), subscription);
|
||||
}
|
||||
|
||||
|
||||
UserPtr AccessControlManager::authorizeAndGetUser(
|
||||
const UUID & user_id,
|
||||
const String & password,
|
||||
const Poco::Net::IPAddress & address,
|
||||
std::function<void(const UserPtr &)> on_change,
|
||||
ext::scope_guard * subscription) const
|
||||
{
|
||||
auto user = getUser(user_id, on_change, subscription);
|
||||
user->allowed_client_hosts.checkContains(address, user->getName());
|
||||
user->authentication.checkPassword(password, user->getName());
|
||||
return user;
|
||||
}
|
||||
|
||||
@ -85,9 +89,9 @@ void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguratio
|
||||
}
|
||||
|
||||
|
||||
std::shared_ptr<const AccessRightsContext> AccessControlManager::getAccessRightsContext(const ClientInfo & client_info, const AccessRights & granted_to_user, const Settings & settings, const String & current_database)
|
||||
std::shared_ptr<const AccessRightsContext> AccessControlManager::getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database)
|
||||
{
|
||||
return std::make_shared<AccessRightsContext>(client_info, granted_to_user, settings, current_database);
|
||||
return std::make_shared<AccessRightsContext>(user, client_info, settings, current_database);
|
||||
}
|
||||
|
||||
|
||||
|
@ -42,12 +42,12 @@ public:
|
||||
|
||||
void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config);
|
||||
|
||||
UserPtr getUser(const String & user_name) const;
|
||||
UserPtr getUser(const String & user_name, const std::function<void(const UserPtr &)> & on_change, ext::scope_guard * subscription) const;
|
||||
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address) const;
|
||||
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address, const std::function<void(const UserPtr &)> & on_change, ext::scope_guard * subscription) const;
|
||||
UserPtr getUser(const String & user_name, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
|
||||
UserPtr getUser(const UUID & user_id, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
|
||||
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
|
||||
UserPtr authorizeAndGetUser(const UUID & user_id, const String & password, const Poco::Net::IPAddress & address, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
|
||||
|
||||
std::shared_ptr<const AccessRightsContext> getAccessRightsContext(const ClientInfo & client_info, const AccessRights & granted_to_user, const Settings & settings, const String & current_database);
|
||||
std::shared_ptr<const AccessRightsContext> getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database);
|
||||
|
||||
std::shared_ptr<QuotaContext>
|
||||
createQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key);
|
||||
|
@ -1,9 +1,9 @@
|
||||
#include <Access/AccessRights.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <boost/range/adaptor/map.hpp>
|
||||
#include <unordered_map>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
@ -73,6 +73,7 @@ public:
|
||||
inherited_access = src.inherited_access;
|
||||
explicit_grants = src.explicit_grants;
|
||||
partial_revokes = src.partial_revokes;
|
||||
raw_access = src.raw_access;
|
||||
access = src.access;
|
||||
min_access = src.min_access;
|
||||
max_access = src.max_access;
|
||||
@ -114,8 +115,12 @@ public:
|
||||
access_to_grant = grantable;
|
||||
}
|
||||
|
||||
explicit_grants |= access_to_grant - partial_revokes;
|
||||
partial_revokes -= access_to_grant;
|
||||
AccessFlags new_explicit_grants = access_to_grant - partial_revokes;
|
||||
if (level == TABLE_LEVEL)
|
||||
removeExplicitGrantsRec(new_explicit_grants);
|
||||
removePartialRevokesRec(access_to_grant);
|
||||
explicit_grants |= new_explicit_grants;
|
||||
|
||||
calculateAllAccessRec(helper);
|
||||
}
|
||||
|
||||
@ -147,16 +152,27 @@ public:
|
||||
{
|
||||
if constexpr (mode == NORMAL_REVOKE_MODE)
|
||||
{
|
||||
explicit_grants -= access_to_revoke;
|
||||
if (level == TABLE_LEVEL)
|
||||
removeExplicitGrantsRec(access_to_revoke);
|
||||
else
|
||||
removeExplicitGrants(access_to_revoke);
|
||||
}
|
||||
else if constexpr (mode == PARTIAL_REVOKE_MODE)
|
||||
{
|
||||
partial_revokes |= access_to_revoke - explicit_grants;
|
||||
explicit_grants -= access_to_revoke;
|
||||
AccessFlags new_partial_revokes = access_to_revoke - explicit_grants;
|
||||
if (level == TABLE_LEVEL)
|
||||
removeExplicitGrantsRec(access_to_revoke);
|
||||
else
|
||||
removeExplicitGrants(access_to_revoke);
|
||||
removePartialRevokesRec(new_partial_revokes);
|
||||
partial_revokes |= new_partial_revokes;
|
||||
}
|
||||
else /// mode == FULL_REVOKE_MODE
|
||||
{
|
||||
fullRevokeRec(access_to_revoke);
|
||||
AccessFlags new_partial_revokes = access_to_revoke - explicit_grants;
|
||||
removeExplicitGrantsRec(access_to_revoke);
|
||||
removePartialRevokesRec(new_partial_revokes);
|
||||
partial_revokes |= new_partial_revokes;
|
||||
}
|
||||
calculateAllAccessRec(helper);
|
||||
}
|
||||
@ -272,6 +288,24 @@ public:
|
||||
calculateAllAccessRec(helper);
|
||||
}
|
||||
|
||||
void traceTree(Poco::Logger * log) const
|
||||
{
|
||||
LOG_TRACE(log, "Tree(" << level << "): name=" << (node_name ? *node_name : "NULL")
|
||||
<< ", explicit_grants=" << explicit_grants.toString()
|
||||
<< ", partial_revokes=" << partial_revokes.toString()
|
||||
<< ", inherited_access=" << inherited_access.toString()
|
||||
<< ", raw_access=" << raw_access.toString()
|
||||
<< ", access=" << access.toString()
|
||||
<< ", min_access=" << min_access.toString()
|
||||
<< ", max_access=" << max_access.toString()
|
||||
<< ", num_children=" << (children ? children->size() : 0));
|
||||
if (children)
|
||||
{
|
||||
for (auto & child : *children | boost::adaptors::map_values)
|
||||
child.traceTree(log);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Node * tryGetChild(const std::string_view & name)
|
||||
{
|
||||
@ -371,14 +405,28 @@ private:
|
||||
calculateMinAndMaxAccess();
|
||||
}
|
||||
|
||||
void fullRevokeRec(const AccessFlags & access_to_revoke)
|
||||
void removeExplicitGrants(const AccessFlags & change)
|
||||
{
|
||||
explicit_grants -= access_to_revoke;
|
||||
partial_revokes |= access_to_revoke;
|
||||
explicit_grants -= change;
|
||||
}
|
||||
|
||||
void removeExplicitGrantsRec(const AccessFlags & change)
|
||||
{
|
||||
removeExplicitGrants(change);
|
||||
if (children)
|
||||
{
|
||||
for (auto & child : *children | boost::adaptors::map_values)
|
||||
child.fullRevokeRec(access_to_revoke);
|
||||
child.removeExplicitGrantsRec(change);
|
||||
}
|
||||
}
|
||||
|
||||
void removePartialRevokesRec(const AccessFlags & change)
|
||||
{
|
||||
partial_revokes -= change;
|
||||
if (children)
|
||||
{
|
||||
for (auto & child : *children | boost::adaptors::map_values)
|
||||
child.removePartialRevokesRec(change);
|
||||
}
|
||||
}
|
||||
|
||||
@ -726,4 +774,13 @@ void AccessRights::merge(const AccessRights & other)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void AccessRights::traceTree() const
|
||||
{
|
||||
auto * log = &Poco::Logger::get("AccessRights");
|
||||
if (root)
|
||||
root->traceTree(log);
|
||||
else
|
||||
LOG_TRACE(log, "Tree: NULL");
|
||||
}
|
||||
}
|
||||
|
@ -130,6 +130,8 @@ private:
|
||||
template <typename... Args>
|
||||
AccessFlags getAccessImpl(const Args &... args) const;
|
||||
|
||||
void traceTree() const;
|
||||
|
||||
struct Node;
|
||||
std::unique_ptr<Node> root;
|
||||
};
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <Access/AccessRightsContext.h>
|
||||
#include <Access/User.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/quoteString.h>
|
||||
#include <Core/Settings.h>
|
||||
@ -88,24 +89,23 @@ AccessRightsContext::AccessRightsContext()
|
||||
}
|
||||
|
||||
|
||||
AccessRightsContext::AccessRightsContext(const ClientInfo & client_info_, const AccessRights & granted_to_user_, const Settings & settings, const String & current_database_)
|
||||
: user_name(client_info_.current_user)
|
||||
, granted_to_user(granted_to_user_)
|
||||
AccessRightsContext::AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_)
|
||||
: user(user_)
|
||||
, readonly(settings.readonly)
|
||||
, allow_ddl(settings.allow_ddl)
|
||||
, allow_introspection(settings.allow_introspection_functions)
|
||||
, current_database(current_database_)
|
||||
, interface(client_info_.interface)
|
||||
, http_method(client_info_.http_method)
|
||||
, trace_log(&Poco::Logger::get("AccessRightsContext (" + user_name + ")"))
|
||||
, trace_log(&Poco::Logger::get("AccessRightsContext (" + user_->getName() + ")"))
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
template <int mode, typename... Args>
|
||||
template <int mode, bool grant_option, typename... Args>
|
||||
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const
|
||||
{
|
||||
auto result_access = calculateResultAccess();
|
||||
auto result_access = calculateResultAccess(grant_option);
|
||||
bool is_granted = result_access->isGranted(access, args...);
|
||||
|
||||
if (trace_log)
|
||||
@ -126,12 +126,21 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
|
||||
auto show_error = [&](const String & msg, [[maybe_unused]] int error_code)
|
||||
{
|
||||
if constexpr (mode == THROW_IF_ACCESS_DENIED)
|
||||
throw Exception(msg, error_code);
|
||||
throw Exception(user->getName() + ": " + msg, error_code);
|
||||
else if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED)
|
||||
LOG_WARNING(log_, msg + formatSkippedMessage(args...));
|
||||
LOG_WARNING(log_, user->getName() + ": " + msg + formatSkippedMessage(args...));
|
||||
};
|
||||
|
||||
if (readonly && calculateResultAccess(false, allow_ddl, allow_introspection)->isGranted(access, args...))
|
||||
if (grant_option && calculateResultAccess(false, readonly, allow_ddl, allow_introspection)->isGranted(access, args...))
|
||||
{
|
||||
show_error(
|
||||
"Not enough privileges. "
|
||||
"The required privileges have been granted, but without grant option. "
|
||||
"To execute this query it's necessary to have the grant "
|
||||
+ AccessRightsElement{access, args...}.toString() + " WITH GRANT OPTION",
|
||||
ErrorCodes::ACCESS_DENIED);
|
||||
}
|
||||
else if (readonly && calculateResultAccess(false, false, allow_ddl, allow_introspection)->isGranted(access, args...))
|
||||
{
|
||||
if (interface == ClientInfo::Interface::HTTP && http_method == ClientInfo::HTTPMethod::GET)
|
||||
show_error(
|
||||
@ -141,108 +150,116 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
|
||||
else
|
||||
show_error("Cannot execute query in readonly mode", ErrorCodes::READONLY);
|
||||
}
|
||||
else if (!allow_ddl && calculateResultAccess(readonly, true, allow_introspection)->isGranted(access, args...))
|
||||
else if (!allow_ddl && calculateResultAccess(false, readonly, true, allow_introspection)->isGranted(access, args...))
|
||||
{
|
||||
show_error("Cannot execute query. DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED);
|
||||
}
|
||||
else if (!allow_introspection && calculateResultAccess(readonly, allow_ddl, true)->isGranted(access, args...))
|
||||
else if (!allow_introspection && calculateResultAccess(false, readonly, allow_ddl, true)->isGranted(access, args...))
|
||||
{
|
||||
show_error("Introspection functions are disabled, because setting 'allow_introspection_functions' is set to 0", ErrorCodes::FUNCTION_NOT_ALLOWED);
|
||||
}
|
||||
else
|
||||
{
|
||||
show_error(
|
||||
user_name + ": Not enough privileges. To perform this operation you should have grant "
|
||||
+ AccessRightsElement{access, args...}.toString(),
|
||||
"Not enough privileges. To execute this query it's necessary to have the grant "
|
||||
+ AccessRightsElement{access, args...}.toString() + (grant_option ? " WITH GRANT OPTION" : ""),
|
||||
ErrorCodes::ACCESS_DENIED);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
template <int mode>
|
||||
|
||||
template <int mode, bool grant_option>
|
||||
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElement & element) const
|
||||
{
|
||||
if (element.any_database)
|
||||
{
|
||||
return checkImpl<mode>(log_, element.access_flags);
|
||||
return checkImpl<mode, grant_option>(log_, element.access_flags);
|
||||
}
|
||||
else if (element.any_table)
|
||||
{
|
||||
if (element.database.empty())
|
||||
return checkImpl<mode>(log_, element.access_flags, current_database);
|
||||
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database);
|
||||
else
|
||||
return checkImpl<mode>(log_, element.access_flags, element.database);
|
||||
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database);
|
||||
}
|
||||
else if (element.any_column)
|
||||
{
|
||||
if (element.database.empty())
|
||||
return checkImpl<mode>(log_, element.access_flags, current_database, element.table);
|
||||
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database, element.table);
|
||||
else
|
||||
return checkImpl<mode>(log_, element.access_flags, element.database, element.table);
|
||||
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (element.database.empty())
|
||||
return checkImpl<mode>(log_, element.access_flags, current_database, element.table, element.columns);
|
||||
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database, element.table, element.columns);
|
||||
else
|
||||
return checkImpl<mode>(log_, element.access_flags, element.database, element.table, element.columns);
|
||||
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table, element.columns);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <int mode>
|
||||
template <int mode, bool grant_option>
|
||||
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElements & elements) const
|
||||
{
|
||||
for (const auto & element : elements)
|
||||
if (!checkImpl<mode>(log_, element))
|
||||
if (!checkImpl<mode, grant_option>(log_, element))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
void AccessRightsContext::check(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table, column); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
|
||||
void AccessRightsContext::check(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access); }
|
||||
void AccessRightsContext::check(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access); }
|
||||
void AccessRightsContext::check(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
|
||||
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
|
||||
void AccessRightsContext::check(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
|
||||
void AccessRightsContext::check(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
|
||||
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table, column); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
|
||||
bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access); }
|
||||
bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
|
||||
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
|
||||
bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
|
||||
bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
|
||||
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table, column); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table, columns); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table, columns); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, column); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
|
||||
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
|
||||
|
||||
void AccessRightsContext::checkGrantOption(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
|
||||
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database); }
|
||||
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table); }
|
||||
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, column); }
|
||||
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); }
|
||||
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); }
|
||||
void AccessRightsContext::checkGrantOption(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
|
||||
void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
|
||||
|
||||
|
||||
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess() const
|
||||
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option) const
|
||||
{
|
||||
auto res = result_access_cache[0].load();
|
||||
if (res)
|
||||
return res;
|
||||
return calculateResultAccess(readonly, allow_ddl, allow_introspection);
|
||||
return calculateResultAccess(grant_option, readonly, allow_ddl, allow_introspection);
|
||||
}
|
||||
|
||||
|
||||
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const
|
||||
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const
|
||||
{
|
||||
size_t cache_index = static_cast<size_t>(readonly_ != readonly)
|
||||
+ static_cast<size_t>(allow_ddl_ != allow_ddl) * 2 +
|
||||
+ static_cast<size_t>(allow_introspection_ != allow_introspection) * 3;
|
||||
+ static_cast<size_t>(allow_introspection_ != allow_introspection) * 3
|
||||
+ static_cast<size_t>(grant_option) * 4;
|
||||
assert(cache_index < std::size(result_access_cache));
|
||||
auto cached = result_access_cache[cache_index].load();
|
||||
if (cached)
|
||||
@ -256,7 +273,7 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
|
||||
auto result_ptr = boost::make_shared<AccessRights>();
|
||||
auto & result = *result_ptr;
|
||||
|
||||
result = granted_to_user;
|
||||
result = grant_option ? user->access_with_grant_option : user->access;
|
||||
|
||||
static const AccessFlags table_ddl = AccessType::CREATE_DATABASE | AccessType::CREATE_TABLE | AccessType::CREATE_VIEW
|
||||
| AccessType::ALTER_TABLE | AccessType::ALTER_VIEW | AccessType::DROP_DATABASE | AccessType::DROP_TABLE | AccessType::DROP_VIEW
|
||||
@ -265,12 +282,18 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
|
||||
static const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl;
|
||||
static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE;
|
||||
|
||||
/// Anyone has access to the "system" database.
|
||||
result.grant(AccessType::SELECT, "system");
|
||||
|
||||
if (readonly_)
|
||||
result.fullRevoke(write_table_access | AccessType::SYSTEM);
|
||||
|
||||
if (readonly_ || !allow_ddl_)
|
||||
result.fullRevoke(table_and_dictionary_ddl);
|
||||
|
||||
if (readonly_ && grant_option)
|
||||
result.fullRevoke(AccessType::ALL);
|
||||
|
||||
if (readonly_ == 1)
|
||||
{
|
||||
/// Table functions are forbidden in readonly mode.
|
||||
@ -282,7 +305,11 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
|
||||
result.fullRevoke(AccessType::INTROSPECTION);
|
||||
|
||||
result_access_cache[cache_index].store(result_ptr);
|
||||
return std::move(result_ptr);
|
||||
|
||||
if (trace_log && (readonly == readonly_) && (allow_ddl == allow_ddl_) && (allow_introspection == allow_introspection_))
|
||||
LOG_TRACE(trace_log, "List of all grants: " << result_ptr->toString() << (grant_option ? " WITH GRANT OPTION" : ""));
|
||||
|
||||
return result_ptr;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -11,6 +11,8 @@ namespace Poco { class Logger; }
|
||||
namespace DB
|
||||
{
|
||||
struct Settings;
|
||||
struct User;
|
||||
using UserPtr = std::shared_ptr<const User>;
|
||||
|
||||
|
||||
class AccessRightsContext
|
||||
@ -19,7 +21,7 @@ public:
|
||||
/// Default constructor creates access rights' context which allows everything.
|
||||
AccessRightsContext();
|
||||
|
||||
AccessRightsContext(const ClientInfo & client_info_, const AccessRights & granted_to_user, const Settings & settings, const String & current_database_);
|
||||
AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_);
|
||||
|
||||
/// Checks if a specified access granted, and throws an exception if not.
|
||||
/// Empty database means the current database.
|
||||
@ -52,21 +54,30 @@ public:
|
||||
bool isGranted(Poco::Logger * log_, const AccessRightsElement & access) const;
|
||||
bool isGranted(Poco::Logger * log_, const AccessRightsElements & access) const;
|
||||
|
||||
/// Checks if a specified access granted with grant option, and throws an exception if not.
|
||||
void checkGrantOption(const AccessFlags & access) const;
|
||||
void checkGrantOption(const AccessFlags & access, const std::string_view & database) const;
|
||||
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const;
|
||||
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const;
|
||||
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const;
|
||||
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const;
|
||||
void checkGrantOption(const AccessRightsElement & access) const;
|
||||
void checkGrantOption(const AccessRightsElements & access) const;
|
||||
|
||||
private:
|
||||
template <int mode, typename... Args>
|
||||
template <int mode, bool grant_option, typename... Args>
|
||||
bool checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const;
|
||||
|
||||
template <int mode>
|
||||
template <int mode, bool grant_option>
|
||||
bool checkImpl(Poco::Logger * log_, const AccessRightsElement & access) const;
|
||||
|
||||
template <int mode>
|
||||
template <int mode, bool grant_option>
|
||||
bool checkImpl(Poco::Logger * log_, const AccessRightsElements & access) const;
|
||||
|
||||
boost::shared_ptr<const AccessRights> calculateResultAccess() const;
|
||||
boost::shared_ptr<const AccessRights> calculateResultAccess(UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const;
|
||||
boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option) const;
|
||||
boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const;
|
||||
|
||||
const String user_name;
|
||||
const AccessRights granted_to_user;
|
||||
const UserPtr user;
|
||||
const UInt64 readonly = 0;
|
||||
const bool allow_ddl = true;
|
||||
const bool allow_introspection = true;
|
||||
@ -74,7 +85,7 @@ private:
|
||||
const ClientInfo::Interface interface = ClientInfo::Interface::TCP;
|
||||
const ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN;
|
||||
Poco::Logger * const trace_log = nullptr;
|
||||
mutable boost::atomic_shared_ptr<const AccessRights> result_access_cache[4];
|
||||
mutable boost::atomic_shared_ptr<const AccessRights> result_access_cache[7];
|
||||
mutable std::mutex mutex;
|
||||
};
|
||||
|
||||
|
@ -1,15 +1,12 @@
|
||||
#include <Access/AllowedClientHosts.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <common/SimpleCache.h>
|
||||
#include <Common/StringUtils/StringUtils.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <Functions/likePatternToRegexp.h>
|
||||
#include <Poco/Net/SocketAddress.h>
|
||||
#include <Poco/RegularExpression.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <ext/scope_guard.h>
|
||||
#include <boost/range/algorithm/find.hpp>
|
||||
#include <boost/range/algorithm/find_first_of.hpp>
|
||||
#include <boost/algorithm/string/predicate.hpp>
|
||||
#include <boost/algorithm/string/replace.hpp>
|
||||
#include <ifaddrs.h>
|
||||
|
||||
|
||||
@ -27,20 +24,6 @@ namespace
|
||||
using IPSubnet = AllowedClientHosts::IPSubnet;
|
||||
const IPSubnet ALL_ADDRESSES{IPAddress{IPAddress::IPv6}, IPAddress{IPAddress::IPv6}};
|
||||
|
||||
const IPAddress & getIPV6Loopback()
|
||||
{
|
||||
static const IPAddress ip("::1");
|
||||
return ip;
|
||||
}
|
||||
|
||||
bool isIPV4LoopbackMappedToIPV6(const IPAddress & ip)
|
||||
{
|
||||
static const IPAddress prefix("::ffff:127.0.0.0");
|
||||
/// 104 == 128 - 24, we have to reset the lowest 24 bits of 128 before comparing with `prefix`
|
||||
/// (IPv4 loopback means any IP from 127.0.0.0 to 127.255.255.255).
|
||||
return (ip & IPAddress(104, IPAddress::IPv6)) == prefix;
|
||||
}
|
||||
|
||||
/// Converts an address to IPv6.
|
||||
/// The loopback address "127.0.0.1" (or any "127.x.y.z") is converted to "::1".
|
||||
IPAddress toIPv6(const IPAddress & ip)
|
||||
@ -52,35 +35,18 @@ namespace
|
||||
v6 = IPAddress("::ffff:" + ip.toString());
|
||||
|
||||
// ::ffff:127.XX.XX.XX -> ::1
|
||||
if (isIPV4LoopbackMappedToIPV6(v6))
|
||||
v6 = getIPV6Loopback();
|
||||
if ((v6 & IPAddress(104, IPAddress::IPv6)) == IPAddress("::ffff:127.0.0.0"))
|
||||
v6 = IPAddress{"::1"};
|
||||
|
||||
return v6;
|
||||
}
|
||||
|
||||
/// Converts a subnet to IPv6.
|
||||
IPSubnet toIPv6(const IPSubnet & subnet)
|
||||
{
|
||||
IPSubnet v6;
|
||||
if (subnet.prefix.family() == IPAddress::IPv6)
|
||||
v6.prefix = subnet.prefix;
|
||||
else
|
||||
v6.prefix = IPAddress("::ffff:" + subnet.prefix.toString());
|
||||
|
||||
if (subnet.mask.family() == IPAddress::IPv6)
|
||||
v6.mask = subnet.mask;
|
||||
else
|
||||
v6.mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + subnet.mask.toString());
|
||||
|
||||
v6.prefix = v6.prefix & v6.mask;
|
||||
|
||||
// ::ffff:127.XX.XX.XX -> ::1
|
||||
if (isIPV4LoopbackMappedToIPV6(v6.prefix))
|
||||
v6 = {getIPV6Loopback(), IPAddress(128, IPAddress::IPv6)};
|
||||
|
||||
return v6;
|
||||
return IPSubnet(toIPv6(subnet.getPrefix()), subnet.getMask());
|
||||
}
|
||||
|
||||
|
||||
/// Helper function for isAddressOfHost().
|
||||
bool isAddressOfHostImpl(const IPAddress & address, const String & host)
|
||||
{
|
||||
@ -150,7 +116,7 @@ namespace
|
||||
|
||||
int err = getifaddrs(&ifa_begin);
|
||||
if (err)
|
||||
return {getIPV6Loopback()};
|
||||
return {IPAddress{"::1"}};
|
||||
|
||||
for (const ifaddrs * ifa = ifa_begin; ifa; ifa = ifa->ifa_next)
|
||||
{
|
||||
@ -203,163 +169,203 @@ namespace
|
||||
static SimpleCache<decltype(getHostByAddressImpl), &getHostByAddressImpl> cache;
|
||||
return cache(address);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
String AllowedClientHosts::IPSubnet::toString() const
|
||||
{
|
||||
unsigned int prefix_length = mask.prefixLength();
|
||||
if (IPAddress{prefix_length, mask.family()} == mask)
|
||||
return prefix.toString() + "/" + std::to_string(prefix_length);
|
||||
|
||||
return prefix.toString() + "/" + mask.toString();
|
||||
}
|
||||
|
||||
|
||||
AllowedClientHosts::AllowedClientHosts()
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
AllowedClientHosts::AllowedClientHosts(AllAddressesTag)
|
||||
{
|
||||
addAllAddresses();
|
||||
}
|
||||
|
||||
|
||||
AllowedClientHosts::~AllowedClientHosts() = default;
|
||||
|
||||
|
||||
AllowedClientHosts::AllowedClientHosts(const AllowedClientHosts & src)
|
||||
{
|
||||
*this = src;
|
||||
}
|
||||
|
||||
|
||||
AllowedClientHosts & AllowedClientHosts::operator =(const AllowedClientHosts & src)
|
||||
{
|
||||
addresses = src.addresses;
|
||||
localhost = src.localhost;
|
||||
subnets = src.subnets;
|
||||
host_names = src.host_names;
|
||||
host_regexps = src.host_regexps;
|
||||
compiled_host_regexps.clear();
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
AllowedClientHosts::AllowedClientHosts(AllowedClientHosts && src) = default;
|
||||
AllowedClientHosts & AllowedClientHosts::operator =(AllowedClientHosts && src) = default;
|
||||
|
||||
|
||||
void AllowedClientHosts::clear()
|
||||
{
|
||||
addresses.clear();
|
||||
localhost = false;
|
||||
subnets.clear();
|
||||
host_names.clear();
|
||||
host_regexps.clear();
|
||||
compiled_host_regexps.clear();
|
||||
}
|
||||
|
||||
|
||||
bool AllowedClientHosts::empty() const
|
||||
{
|
||||
return addresses.empty() && subnets.empty() && host_names.empty() && host_regexps.empty();
|
||||
}
|
||||
|
||||
|
||||
void AllowedClientHosts::addAddress(const IPAddress & address)
|
||||
{
|
||||
IPAddress addr_v6 = toIPv6(address);
|
||||
if (boost::range::find(addresses, addr_v6) != addresses.end())
|
||||
return;
|
||||
addresses.push_back(addr_v6);
|
||||
if (addr_v6.isLoopback())
|
||||
localhost = true;
|
||||
}
|
||||
|
||||
|
||||
void AllowedClientHosts::addAddress(const String & address)
|
||||
{
|
||||
addAddress(IPAddress{address});
|
||||
}
|
||||
|
||||
|
||||
void AllowedClientHosts::addSubnet(const IPSubnet & subnet)
|
||||
{
|
||||
IPSubnet subnet_v6 = toIPv6(subnet);
|
||||
|
||||
if (subnet_v6.mask == IPAddress(128, IPAddress::IPv6))
|
||||
void parseLikePatternIfIPSubnet(const String & pattern, IPSubnet & subnet, IPAddress::Family address_family)
|
||||
{
|
||||
addAddress(subnet_v6.prefix);
|
||||
return;
|
||||
size_t slash = pattern.find('/');
|
||||
if (slash != String::npos)
|
||||
{
|
||||
/// IP subnet, e.g. "192.168.0.0/16" or "192.168.0.0/255.255.0.0".
|
||||
subnet = IPSubnet{pattern};
|
||||
return;
|
||||
}
|
||||
|
||||
bool has_wildcard = (pattern.find_first_of("%_") != String::npos);
|
||||
if (has_wildcard)
|
||||
{
|
||||
/// IP subnet specified with one of the wildcard characters, e.g. "192.168.%.%".
|
||||
String wildcard_replaced_with_zero_bits = pattern;
|
||||
String wildcard_replaced_with_one_bits = pattern;
|
||||
if (address_family == IPAddress::IPv6)
|
||||
{
|
||||
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "_", "0");
|
||||
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "%", "0000");
|
||||
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "_", "f");
|
||||
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "%", "ffff");
|
||||
}
|
||||
else if (address_family == IPAddress::IPv4)
|
||||
{
|
||||
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "%", "0");
|
||||
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "%", "255");
|
||||
}
|
||||
|
||||
IPAddress prefix{wildcard_replaced_with_zero_bits};
|
||||
IPAddress mask = ~(prefix ^ IPAddress{wildcard_replaced_with_one_bits});
|
||||
subnet = IPSubnet{prefix, mask};
|
||||
return;
|
||||
}
|
||||
|
||||
/// Exact IP address.
|
||||
subnet = IPSubnet{pattern};
|
||||
}
|
||||
|
||||
if (boost::range::find(subnets, subnet_v6) == subnets.end())
|
||||
subnets.push_back(subnet_v6);
|
||||
}
|
||||
|
||||
|
||||
void AllowedClientHosts::addSubnet(const IPAddress & prefix, const IPAddress & mask)
|
||||
{
|
||||
addSubnet(IPSubnet{prefix, mask});
|
||||
}
|
||||
|
||||
|
||||
void AllowedClientHosts::addSubnet(const IPAddress & prefix, size_t num_prefix_bits)
|
||||
{
|
||||
addSubnet(prefix, IPAddress(num_prefix_bits, prefix.family()));
|
||||
}
|
||||
|
||||
|
||||
void AllowedClientHosts::addSubnet(const String & subnet)
|
||||
{
|
||||
size_t slash = subnet.find('/');
|
||||
if (slash == String::npos)
|
||||
/// Extracts a subnet, a host name or a host name regular expession from a like pattern.
|
||||
void parseLikePattern(
|
||||
const String & pattern, std::optional<IPSubnet> & subnet, std::optional<String> & name, std::optional<String> & name_regexp)
|
||||
{
|
||||
addAddress(subnet);
|
||||
return;
|
||||
/// If `host` starts with digits and a dot then it's an IP pattern, otherwise it's a hostname pattern.
|
||||
size_t first_not_digit = pattern.find_first_not_of("0123456789");
|
||||
if ((first_not_digit != String::npos) && (first_not_digit != 0) && (pattern[first_not_digit] == '.'))
|
||||
{
|
||||
parseLikePatternIfIPSubnet(pattern, subnet.emplace(), IPAddress::IPv4);
|
||||
return;
|
||||
}
|
||||
|
||||
size_t first_not_hex = pattern.find_first_not_of("0123456789ABCDEFabcdef");
|
||||
if (((first_not_hex == 4) && pattern[first_not_hex] == ':') || pattern.starts_with("::"))
|
||||
{
|
||||
parseLikePatternIfIPSubnet(pattern, subnet.emplace(), IPAddress::IPv6);
|
||||
return;
|
||||
}
|
||||
|
||||
bool has_wildcard = (pattern.find_first_of("%_") != String::npos);
|
||||
if (has_wildcard)
|
||||
{
|
||||
name_regexp = likePatternToRegexp(pattern);
|
||||
return;
|
||||
}
|
||||
|
||||
name = pattern;
|
||||
}
|
||||
|
||||
IPAddress prefix{String{subnet, 0, slash}};
|
||||
String mask(subnet, slash + 1, subnet.length() - slash - 1);
|
||||
if (std::all_of(mask.begin(), mask.end(), isNumericASCII))
|
||||
addSubnet(prefix, parseFromString<UInt8>(mask));
|
||||
else
|
||||
addSubnet(prefix, IPAddress{mask});
|
||||
}
|
||||
|
||||
|
||||
void AllowedClientHosts::addHostName(const String & host_name)
|
||||
bool AllowedClientHosts::contains(const IPAddress & client_address) const
|
||||
{
|
||||
if (boost::range::find(host_names, host_name) != host_names.end())
|
||||
return;
|
||||
host_names.push_back(host_name);
|
||||
if (boost::iequals(host_name, "localhost"))
|
||||
localhost = true;
|
||||
}
|
||||
if (any_host)
|
||||
return true;
|
||||
|
||||
IPAddress client_v6 = toIPv6(client_address);
|
||||
|
||||
void AllowedClientHosts::addHostRegexp(const String & host_regexp)
|
||||
{
|
||||
if (boost::range::find(host_regexps, host_regexp) == host_regexps.end())
|
||||
host_regexps.push_back(host_regexp);
|
||||
}
|
||||
std::optional<bool> is_client_local_value;
|
||||
auto is_client_local = [&]
|
||||
{
|
||||
if (is_client_local_value)
|
||||
return *is_client_local_value;
|
||||
is_client_local_value = isAddressOfLocalhost(client_v6);
|
||||
return *is_client_local_value;
|
||||
};
|
||||
|
||||
if (local_host && is_client_local())
|
||||
return true;
|
||||
|
||||
void AllowedClientHosts::addAllAddresses()
|
||||
{
|
||||
clear();
|
||||
addSubnet(ALL_ADDRESSES);
|
||||
}
|
||||
/// Check `addresses`.
|
||||
auto check_address = [&](const IPAddress & address_)
|
||||
{
|
||||
IPAddress address_v6 = toIPv6(address_);
|
||||
if (address_v6.isLoopback())
|
||||
return is_client_local();
|
||||
return address_v6 == client_v6;
|
||||
};
|
||||
|
||||
for (const auto & address : addresses)
|
||||
if (check_address(address))
|
||||
return true;
|
||||
|
||||
bool AllowedClientHosts::containsAllAddresses() const
|
||||
{
|
||||
return (boost::range::find(subnets, ALL_ADDRESSES) != subnets.end())
|
||||
|| (boost::range::find(host_regexps, ".*") != host_regexps.end())
|
||||
|| (boost::range::find(host_regexps, "$") != host_regexps.end());
|
||||
/// Check `subnets`.
|
||||
auto check_subnet = [&](const IPSubnet & subnet_)
|
||||
{
|
||||
IPSubnet subnet_v6 = toIPv6(subnet_);
|
||||
if (subnet_v6.isMaskAllBitsOne())
|
||||
return check_address(subnet_v6.getPrefix());
|
||||
return (client_v6 & subnet_v6.getMask()) == subnet_v6.getPrefix();
|
||||
};
|
||||
|
||||
for (const auto & subnet : subnets)
|
||||
if (check_subnet(subnet))
|
||||
return true;
|
||||
|
||||
/// Check `names`.
|
||||
auto check_name = [&](const String & name_)
|
||||
{
|
||||
if (boost::iequals(name_, "localhost"))
|
||||
return is_client_local();
|
||||
try
|
||||
{
|
||||
return isAddressOfHost(client_v6, name_);
|
||||
}
|
||||
catch (const Exception & e)
|
||||
{
|
||||
if (e.code() != ErrorCodes::DNS_ERROR)
|
||||
throw;
|
||||
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
|
||||
LOG_WARNING(
|
||||
&Logger::get("AddressPatterns"),
|
||||
"Failed to check if the allowed client hosts contain address " << client_address.toString() << ". " << e.displayText()
|
||||
<< ", code = " << e.code());
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
for (const String & name : names)
|
||||
if (check_name(name))
|
||||
return true;
|
||||
|
||||
/// Check `name_regexps`.
|
||||
std::optional<String> resolved_host;
|
||||
auto check_name_regexp = [&](const String & name_regexp_)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (boost::iequals(name_regexp_, "localhost"))
|
||||
return is_client_local();
|
||||
if (!resolved_host)
|
||||
resolved_host = getHostByAddress(client_v6);
|
||||
if (resolved_host->empty())
|
||||
return false;
|
||||
Poco::RegularExpression re(name_regexp_);
|
||||
Poco::RegularExpression::Match match;
|
||||
return re.match(*resolved_host, match) != 0;
|
||||
}
|
||||
catch (const Exception & e)
|
||||
{
|
||||
if (e.code() != ErrorCodes::DNS_ERROR)
|
||||
throw;
|
||||
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
|
||||
LOG_WARNING(
|
||||
&Logger::get("AddressPatterns"),
|
||||
"Failed to check if the allowed client hosts contain address " << client_address.toString() << ". " << e.displayText()
|
||||
<< ", code = " << e.code());
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
for (const String & name_regexp : name_regexps)
|
||||
if (check_name_regexp(name_regexp))
|
||||
return true;
|
||||
|
||||
auto check_like_pattern = [&](const String & pattern)
|
||||
{
|
||||
std::optional<IPSubnet> subnet;
|
||||
std::optional<String> name;
|
||||
std::optional<String> name_regexp;
|
||||
parseLikePattern(pattern, subnet, name, name_regexp);
|
||||
if (subnet)
|
||||
return check_subnet(*subnet);
|
||||
else if (name)
|
||||
return check_name(*name);
|
||||
else if (name_regexp)
|
||||
return check_name_regexp(*name_regexp);
|
||||
else
|
||||
return false;
|
||||
};
|
||||
|
||||
for (const String & like_pattern : like_patterns)
|
||||
if (check_like_pattern(like_pattern))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@ -374,86 +380,4 @@ void AllowedClientHosts::checkContains(const IPAddress & address, const String &
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
bool AllowedClientHosts::contains(const IPAddress & address) const
|
||||
{
|
||||
/// Check `ip_addresses`.
|
||||
IPAddress addr_v6 = toIPv6(address);
|
||||
if (boost::range::find(addresses, addr_v6) != addresses.end())
|
||||
return true;
|
||||
|
||||
if (localhost && isAddressOfLocalhost(addr_v6))
|
||||
return true;
|
||||
|
||||
/// Check `ip_subnets`.
|
||||
for (const auto & subnet : subnets)
|
||||
if ((addr_v6 & subnet.mask) == subnet.prefix)
|
||||
return true;
|
||||
|
||||
/// Check `hosts`.
|
||||
for (const String & host_name : host_names)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (isAddressOfHost(addr_v6, host_name))
|
||||
return true;
|
||||
}
|
||||
catch (const Exception & e)
|
||||
{
|
||||
if (e.code() != ErrorCodes::DNS_ERROR)
|
||||
throw;
|
||||
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
|
||||
LOG_WARNING(
|
||||
&Logger::get("AddressPatterns"),
|
||||
"Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText()
|
||||
<< ", code = " << e.code());
|
||||
}
|
||||
}
|
||||
|
||||
/// Check `host_regexps`.
|
||||
try
|
||||
{
|
||||
String resolved_host = getHostByAddress(addr_v6);
|
||||
if (!resolved_host.empty())
|
||||
{
|
||||
compileRegexps();
|
||||
for (const auto & compiled_regexp : compiled_host_regexps)
|
||||
{
|
||||
Poco::RegularExpression::Match match;
|
||||
if (compiled_regexp && compiled_regexp->match(resolved_host, match))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (const Exception & e)
|
||||
{
|
||||
if (e.code() != ErrorCodes::DNS_ERROR)
|
||||
throw;
|
||||
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
|
||||
LOG_WARNING(
|
||||
&Logger::get("AddressPatterns"),
|
||||
"Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText()
|
||||
<< ", code = " << e.code());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
void AllowedClientHosts::compileRegexps() const
|
||||
{
|
||||
if (compiled_host_regexps.size() == host_regexps.size())
|
||||
return;
|
||||
size_t old_size = compiled_host_regexps.size();
|
||||
compiled_host_regexps.reserve(host_regexps.size());
|
||||
for (size_t i = old_size; i != host_regexps.size(); ++i)
|
||||
compiled_host_regexps.emplace_back(std::make_unique<Poco::RegularExpression>(host_regexps[i]));
|
||||
}
|
||||
|
||||
|
||||
bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs)
|
||||
{
|
||||
return (lhs.addresses == rhs.addresses) && (lhs.subnets == rhs.subnets) && (lhs.host_names == rhs.host_names)
|
||||
&& (lhs.host_regexps == rhs.host_regexps);
|
||||
}
|
||||
}
|
||||
|
@ -4,12 +4,9 @@
|
||||
#include <Poco/Net/IPAddress.h>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
|
||||
namespace Poco
|
||||
{
|
||||
class RegularExpression;
|
||||
}
|
||||
#include <boost/range/algorithm/find.hpp>
|
||||
#include <boost/range/algorithm_ext/erase.hpp>
|
||||
#include <boost/algorithm/string/predicate.hpp>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -20,69 +17,100 @@ class AllowedClientHosts
|
||||
public:
|
||||
using IPAddress = Poco::Net::IPAddress;
|
||||
|
||||
struct IPSubnet
|
||||
class IPSubnet
|
||||
{
|
||||
IPAddress prefix;
|
||||
IPAddress mask;
|
||||
public:
|
||||
IPSubnet() {}
|
||||
IPSubnet(const IPAddress & prefix_, const IPAddress & mask_) { set(prefix_, mask_); }
|
||||
IPSubnet(const IPAddress & prefix_, size_t num_prefix_bits) { set(prefix_, num_prefix_bits); }
|
||||
explicit IPSubnet(const IPAddress & address) { set(address); }
|
||||
explicit IPSubnet(const String & str);
|
||||
|
||||
const IPAddress & getPrefix() const { return prefix; }
|
||||
const IPAddress & getMask() const { return mask; }
|
||||
bool isMaskAllBitsOne() const;
|
||||
String toString() const;
|
||||
|
||||
friend bool operator ==(const IPSubnet & lhs, const IPSubnet & rhs) { return (lhs.prefix == rhs.prefix) && (lhs.mask == rhs.mask); }
|
||||
friend bool operator !=(const IPSubnet & lhs, const IPSubnet & rhs) { return !(lhs == rhs); }
|
||||
|
||||
private:
|
||||
void set(const IPAddress & prefix_, const IPAddress & mask_);
|
||||
void set(const IPAddress & prefix_, size_t num_prefix_bits);
|
||||
void set(const IPAddress & address);
|
||||
|
||||
IPAddress prefix;
|
||||
IPAddress mask;
|
||||
};
|
||||
|
||||
struct AllAddressesTag {};
|
||||
struct AnyHostTag {};
|
||||
|
||||
AllowedClientHosts();
|
||||
explicit AllowedClientHosts(AllAddressesTag);
|
||||
~AllowedClientHosts();
|
||||
AllowedClientHosts() {}
|
||||
explicit AllowedClientHosts(AnyHostTag) { addAnyHost(); }
|
||||
~AllowedClientHosts() {}
|
||||
|
||||
AllowedClientHosts(const AllowedClientHosts & src);
|
||||
AllowedClientHosts & operator =(const AllowedClientHosts & src);
|
||||
AllowedClientHosts(AllowedClientHosts && src);
|
||||
AllowedClientHosts & operator =(AllowedClientHosts && src);
|
||||
AllowedClientHosts(const AllowedClientHosts & src) = default;
|
||||
AllowedClientHosts & operator =(const AllowedClientHosts & src) = default;
|
||||
AllowedClientHosts(AllowedClientHosts && src) = default;
|
||||
AllowedClientHosts & operator =(AllowedClientHosts && src) = default;
|
||||
|
||||
/// Removes all contained addresses. This will disallow all addresses.
|
||||
/// Removes all contained addresses. This will disallow all hosts.
|
||||
void clear();
|
||||
|
||||
bool empty() const;
|
||||
|
||||
/// Allows exact IP address.
|
||||
/// For example, 213.180.204.3 or 2a02:6b8::3
|
||||
void addAddress(const IPAddress & address);
|
||||
void addAddress(const String & address);
|
||||
|
||||
/// Allows an IP subnet.
|
||||
void addSubnet(const IPSubnet & subnet);
|
||||
void addSubnet(const String & subnet);
|
||||
|
||||
/// Allows an IP subnet.
|
||||
/// For example, 312.234.1.1/255.255.255.0 or 2a02:6b8::3/FFFF:FFFF:FFFF:FFFF::
|
||||
void addSubnet(const IPAddress & prefix, const IPAddress & mask);
|
||||
|
||||
/// Allows an IP subnet.
|
||||
/// For example, 10.0.0.1/8 or 2a02:6b8::3/64
|
||||
void addSubnet(const IPAddress & prefix, size_t num_prefix_bits);
|
||||
|
||||
/// Allows all addresses.
|
||||
void addAllAddresses();
|
||||
|
||||
/// Allows an exact host. The `contains()` function will check that the provided address equals to one of that host's addresses.
|
||||
void addHostName(const String & host_name);
|
||||
|
||||
/// Allows a regular expression for the host.
|
||||
void addHostRegexp(const String & host_regexp);
|
||||
|
||||
void addAddress(const String & address) { addAddress(IPAddress(address)); }
|
||||
void removeAddress(const IPAddress & address);
|
||||
void removeAddress(const String & address) { removeAddress(IPAddress{address}); }
|
||||
const std::vector<IPAddress> & getAddresses() const { return addresses; }
|
||||
|
||||
/// Allows an IP subnet.
|
||||
/// For example, 312.234.1.1/255.255.255.0 or 2a02:6b8::3/64
|
||||
void addSubnet(const IPSubnet & subnet);
|
||||
void addSubnet(const String & subnet) { addSubnet(IPSubnet{subnet}); }
|
||||
void addSubnet(const IPAddress & prefix, const IPAddress & mask) { addSubnet({prefix, mask}); }
|
||||
void addSubnet(const IPAddress & prefix, size_t num_prefix_bits) { addSubnet({prefix, num_prefix_bits}); }
|
||||
void removeSubnet(const IPSubnet & subnet);
|
||||
void removeSubnet(const String & subnet) { removeSubnet(IPSubnet{subnet}); }
|
||||
void removeSubnet(const IPAddress & prefix, const IPAddress & mask) { removeSubnet({prefix, mask}); }
|
||||
void removeSubnet(const IPAddress & prefix, size_t num_prefix_bits) { removeSubnet({prefix, num_prefix_bits}); }
|
||||
const std::vector<IPSubnet> & getSubnets() const { return subnets; }
|
||||
const std::vector<String> & getHostNames() const { return host_names; }
|
||||
const std::vector<String> & getHostRegexps() const { return host_regexps; }
|
||||
|
||||
/// Allows an exact host name. The `contains()` function will check that the provided address equals to one of that host's addresses.
|
||||
void addName(const String & name);
|
||||
void removeName(const String & name);
|
||||
const std::vector<String> & getNames() const { return names; }
|
||||
|
||||
/// Allows the host names matching a regular expression.
|
||||
void addNameRegexp(const String & name_regexp);
|
||||
void removeNameRegexp(const String & name_regexp);
|
||||
const std::vector<String> & getNameRegexps() const { return name_regexps; }
|
||||
|
||||
/// Allows IP addresses or host names using LIKE pattern.
|
||||
/// This pattern can contain % and _ wildcard characters.
|
||||
/// For example, addLikePattern("@") will allow all addresses.
|
||||
void addLikePattern(const String & pattern);
|
||||
void removeLikePattern(const String & like_pattern);
|
||||
const std::vector<String> & getLikePatterns() const { return like_patterns; }
|
||||
|
||||
/// Allows local host.
|
||||
void addLocalHost();
|
||||
void removeLocalHost();
|
||||
bool containsLocalHost() const { return local_host;}
|
||||
|
||||
/// Allows any host.
|
||||
void addAnyHost();
|
||||
bool containsAnyHost() const { return any_host;}
|
||||
|
||||
void add(const AllowedClientHosts & other);
|
||||
void remove(const AllowedClientHosts & other);
|
||||
|
||||
/// Checks if the provided address is in the list. Returns false if not.
|
||||
bool contains(const IPAddress & address) const;
|
||||
|
||||
/// Checks if any address is allowed.
|
||||
bool containsAllAddresses() const;
|
||||
|
||||
/// Checks if the provided address is in the list. Throws an exception if not.
|
||||
/// `username` is only used for generating an error message if the address isn't in the list.
|
||||
void checkContains(const IPAddress & address, const String & user_name = String()) const;
|
||||
@ -91,13 +119,269 @@ public:
|
||||
friend bool operator !=(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs) { return !(lhs == rhs); }
|
||||
|
||||
private:
|
||||
void compileRegexps() const;
|
||||
|
||||
std::vector<IPAddress> addresses;
|
||||
bool localhost = false;
|
||||
std::vector<IPSubnet> subnets;
|
||||
std::vector<String> host_names;
|
||||
std::vector<String> host_regexps;
|
||||
mutable std::vector<std::unique_ptr<Poco::RegularExpression>> compiled_host_regexps;
|
||||
Strings names;
|
||||
Strings name_regexps;
|
||||
Strings like_patterns;
|
||||
bool any_host = false;
|
||||
bool local_host = false;
|
||||
};
|
||||
|
||||
|
||||
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & prefix_, const IPAddress & mask_)
|
||||
{
|
||||
prefix = prefix_;
|
||||
mask = mask_;
|
||||
|
||||
if (prefix.family() != mask.family())
|
||||
{
|
||||
if (prefix.family() == IPAddress::IPv4)
|
||||
prefix = IPAddress("::ffff:" + prefix.toString());
|
||||
|
||||
if (mask.family() == IPAddress::IPv4)
|
||||
mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + mask.toString());
|
||||
}
|
||||
|
||||
prefix = prefix & mask;
|
||||
|
||||
if (prefix.family() == IPAddress::IPv4)
|
||||
{
|
||||
if ((prefix & IPAddress{8, IPAddress::IPv4}) == IPAddress{"127.0.0.0"})
|
||||
{
|
||||
// 127.XX.XX.XX -> 127.0.0.1
|
||||
prefix = IPAddress{"127.0.0.1"};
|
||||
mask = IPAddress{32, IPAddress::IPv4};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if ((prefix & IPAddress{104, IPAddress::IPv6}) == IPAddress{"::ffff:127.0.0.0"})
|
||||
{
|
||||
// ::ffff:127.XX.XX.XX -> ::1
|
||||
prefix = IPAddress{"::1"};
|
||||
mask = IPAddress{128, IPAddress::IPv6};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & prefix_, size_t num_prefix_bits)
|
||||
{
|
||||
set(prefix_, IPAddress(num_prefix_bits, prefix_.family()));
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & address)
|
||||
{
|
||||
set(address, address.length() * 8);
|
||||
}
|
||||
|
||||
inline AllowedClientHosts::IPSubnet::IPSubnet(const String & str)
|
||||
{
|
||||
size_t slash = str.find('/');
|
||||
if (slash == String::npos)
|
||||
{
|
||||
set(IPAddress(str));
|
||||
return;
|
||||
}
|
||||
|
||||
IPAddress new_prefix{String{str, 0, slash}};
|
||||
String mask_str(str, slash + 1, str.length() - slash - 1);
|
||||
bool only_digits = (mask_str.find_first_not_of("0123456789") == std::string::npos);
|
||||
if (only_digits)
|
||||
set(new_prefix, std::stoul(mask_str));
|
||||
else
|
||||
set(new_prefix, IPAddress{mask_str});
|
||||
}
|
||||
|
||||
inline String AllowedClientHosts::IPSubnet::toString() const
|
||||
{
|
||||
unsigned int prefix_length = mask.prefixLength();
|
||||
if (isMaskAllBitsOne())
|
||||
return prefix.toString();
|
||||
else if (IPAddress{prefix_length, mask.family()} == mask)
|
||||
return prefix.toString() + "/" + std::to_string(prefix_length);
|
||||
else
|
||||
return prefix.toString() + "/" + mask.toString();
|
||||
}
|
||||
|
||||
inline bool AllowedClientHosts::IPSubnet::isMaskAllBitsOne() const
|
||||
{
|
||||
return mask == IPAddress(mask.length() * 8, mask.family());
|
||||
}
|
||||
|
||||
|
||||
inline void AllowedClientHosts::clear()
|
||||
{
|
||||
addresses = {};
|
||||
subnets = {};
|
||||
names = {};
|
||||
name_regexps = {};
|
||||
like_patterns = {};
|
||||
any_host = false;
|
||||
local_host = false;
|
||||
}
|
||||
|
||||
inline bool AllowedClientHosts::empty() const
|
||||
{
|
||||
return !any_host && !local_host && addresses.empty() && subnets.empty() && names.empty() && name_regexps.empty() && like_patterns.empty();
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::addAddress(const IPAddress & address)
|
||||
{
|
||||
if (address.isLoopback())
|
||||
local_host = true;
|
||||
else if (boost::range::find(addresses, address) == addresses.end())
|
||||
addresses.push_back(address);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::removeAddress(const IPAddress & address)
|
||||
{
|
||||
if (address.isLoopback())
|
||||
local_host = false;
|
||||
else
|
||||
boost::range::remove_erase(addresses, address);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::addSubnet(const IPSubnet & subnet)
|
||||
{
|
||||
if (subnet.getMask().isWildcard())
|
||||
any_host = true;
|
||||
else if (subnet.isMaskAllBitsOne())
|
||||
addAddress(subnet.getPrefix());
|
||||
else if (boost::range::find(subnets, subnet) == subnets.end())
|
||||
subnets.push_back(subnet);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::removeSubnet(const IPSubnet & subnet)
|
||||
{
|
||||
if (subnet.getMask().isWildcard())
|
||||
any_host = false;
|
||||
else if (subnet.isMaskAllBitsOne())
|
||||
removeAddress(subnet.getPrefix());
|
||||
else
|
||||
boost::range::remove_erase(subnets, subnet);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::addName(const String & name)
|
||||
{
|
||||
if (boost::iequals(name, "localhost"))
|
||||
local_host = true;
|
||||
else if (boost::range::find(names, name) == names.end())
|
||||
names.push_back(name);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::removeName(const String & name)
|
||||
{
|
||||
if (boost::iequals(name, "localhost"))
|
||||
local_host = false;
|
||||
else
|
||||
boost::range::remove_erase(names, name);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::addNameRegexp(const String & name_regexp)
|
||||
{
|
||||
if (boost::iequals(name_regexp, "localhost"))
|
||||
local_host = true;
|
||||
else if (name_regexp == ".*")
|
||||
any_host = true;
|
||||
else if (boost::range::find(name_regexps, name_regexp) == name_regexps.end())
|
||||
name_regexps.push_back(name_regexp);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::removeNameRegexp(const String & name_regexp)
|
||||
{
|
||||
if (boost::iequals(name_regexp, "localhost"))
|
||||
local_host = false;
|
||||
else if (name_regexp == ".*")
|
||||
any_host = false;
|
||||
else
|
||||
boost::range::remove_erase(name_regexps, name_regexp);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::addLikePattern(const String & pattern)
|
||||
{
|
||||
if (boost::iequals(pattern, "localhost") || (pattern == "127.0.0.1") || (pattern == "::1"))
|
||||
local_host = true;
|
||||
else if ((pattern == "@") || (pattern == "0.0.0.0/0") || (pattern == "::/0"))
|
||||
any_host = true;
|
||||
else if (boost::range::find(like_patterns, pattern) == name_regexps.end())
|
||||
like_patterns.push_back(pattern);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::removeLikePattern(const String & pattern)
|
||||
{
|
||||
if (boost::iequals(pattern, "localhost") || (pattern == "127.0.0.1") || (pattern == "::1"))
|
||||
local_host = false;
|
||||
else if ((pattern == "@") || (pattern == "0.0.0.0/0") || (pattern == "::/0"))
|
||||
any_host = false;
|
||||
else
|
||||
boost::range::remove_erase(like_patterns, pattern);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::addLocalHost()
|
||||
{
|
||||
local_host = true;
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::removeLocalHost()
|
||||
{
|
||||
local_host = false;
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::addAnyHost()
|
||||
{
|
||||
clear();
|
||||
any_host = true;
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::add(const AllowedClientHosts & other)
|
||||
{
|
||||
if (other.containsAnyHost())
|
||||
{
|
||||
addAnyHost();
|
||||
return;
|
||||
}
|
||||
if (other.containsLocalHost())
|
||||
addLocalHost();
|
||||
for (const IPAddress & address : other.getAddresses())
|
||||
addAddress(address);
|
||||
for (const IPSubnet & subnet : other.getSubnets())
|
||||
addSubnet(subnet);
|
||||
for (const String & name : other.getNames())
|
||||
addName(name);
|
||||
for (const String & name_regexp : other.getNameRegexps())
|
||||
addNameRegexp(name_regexp);
|
||||
for (const String & like_pattern : other.getLikePatterns())
|
||||
addLikePattern(like_pattern);
|
||||
}
|
||||
|
||||
inline void AllowedClientHosts::remove(const AllowedClientHosts & other)
|
||||
{
|
||||
if (other.containsAnyHost())
|
||||
{
|
||||
clear();
|
||||
return;
|
||||
}
|
||||
if (other.containsLocalHost())
|
||||
removeLocalHost();
|
||||
for (const IPAddress & address : other.getAddresses())
|
||||
removeAddress(address);
|
||||
for (const IPSubnet & subnet : other.getSubnets())
|
||||
removeSubnet(subnet);
|
||||
for (const String & name : other.getNames())
|
||||
removeName(name);
|
||||
for (const String & name_regexp : other.getNameRegexps())
|
||||
removeNameRegexp(name_regexp);
|
||||
for (const String & like_pattern : other.getLikePatterns())
|
||||
removeLikePattern(like_pattern);
|
||||
}
|
||||
|
||||
|
||||
inline bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs)
|
||||
{
|
||||
return (lhs.any_host == rhs.any_host) && (lhs.local_host == rhs.local_host) && (lhs.addresses == rhs.addresses)
|
||||
&& (lhs.subnets == rhs.subnets) && (lhs.names == rhs.names) && (lhs.name_regexps == rhs.name_regexps)
|
||||
&& (lhs.like_patterns == rhs.like_patterns);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,166 +1,18 @@
|
||||
#include <Access/Authentication.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <common/StringRef.h>
|
||||
#include <Core/Defines.h>
|
||||
#include <Poco/SHA1Engine.h>
|
||||
#include <boost/algorithm/hex.hpp>
|
||||
#include "config_core.h"
|
||||
#if USE_SSL
|
||||
# include <openssl/sha.h>
|
||||
#endif
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int SUPPORT_IS_DISABLED;
|
||||
extern const int REQUIRED_PASSWORD;
|
||||
extern const int WRONG_PASSWORD;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
|
||||
namespace
|
||||
{
|
||||
using Digest = Authentication::Digest;
|
||||
|
||||
Digest encodePlainText(const StringRef & text)
|
||||
{
|
||||
return Digest(text.data, text.data + text.size);
|
||||
}
|
||||
|
||||
Digest encodeSHA256(const StringRef & text)
|
||||
{
|
||||
#if USE_SSL
|
||||
Digest hash;
|
||||
hash.resize(32);
|
||||
SHA256_CTX ctx;
|
||||
SHA256_Init(&ctx);
|
||||
SHA256_Update(&ctx, reinterpret_cast<const UInt8 *>(text.data), text.size);
|
||||
SHA256_Final(hash.data(), &ctx);
|
||||
return hash;
|
||||
#else
|
||||
UNUSED(text);
|
||||
throw DB::Exception("SHA256 passwords support is disabled, because ClickHouse was built without SSL library", DB::ErrorCodes::SUPPORT_IS_DISABLED);
|
||||
#endif
|
||||
}
|
||||
|
||||
Digest encodeSHA1(const StringRef & text)
|
||||
{
|
||||
Poco::SHA1Engine engine;
|
||||
engine.update(text.data, text.size);
|
||||
return engine.digest();
|
||||
}
|
||||
|
||||
Digest encodeSHA1(const Digest & text)
|
||||
{
|
||||
return encodeSHA1(StringRef{reinterpret_cast<const char *>(text.data()), text.size()});
|
||||
}
|
||||
|
||||
Digest encodeDoubleSHA1(const StringRef & text)
|
||||
{
|
||||
return encodeSHA1(encodeSHA1(text));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Authentication::Authentication(Authentication::Type type_)
|
||||
: type(type_)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
void Authentication::setPassword(const String & password_)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case NO_PASSWORD:
|
||||
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
case PLAINTEXT_PASSWORD:
|
||||
setPasswordHashBinary(encodePlainText(password_));
|
||||
return;
|
||||
|
||||
case SHA256_PASSWORD:
|
||||
setPasswordHashBinary(encodeSHA256(password_));
|
||||
return;
|
||||
|
||||
case DOUBLE_SHA1_PASSWORD:
|
||||
setPasswordHashBinary(encodeDoubleSHA1(password_));
|
||||
return;
|
||||
}
|
||||
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
|
||||
String Authentication::getPassword() const
|
||||
{
|
||||
if (type != PLAINTEXT_PASSWORD)
|
||||
throw Exception("Cannot decode the password", ErrorCodes::LOGICAL_ERROR);
|
||||
return String(password_hash.data(), password_hash.data() + password_hash.size());
|
||||
}
|
||||
|
||||
|
||||
void Authentication::setPasswordHashHex(const String & hash)
|
||||
{
|
||||
Digest digest;
|
||||
digest.resize(hash.size() / 2);
|
||||
boost::algorithm::unhex(hash.begin(), hash.end(), digest.data());
|
||||
setPasswordHashBinary(digest);
|
||||
}
|
||||
|
||||
|
||||
String Authentication::getPasswordHashHex() const
|
||||
{
|
||||
String hex;
|
||||
hex.resize(password_hash.size() * 2);
|
||||
boost::algorithm::hex(password_hash.begin(), password_hash.end(), hex.data());
|
||||
return hex;
|
||||
}
|
||||
|
||||
|
||||
void Authentication::setPasswordHashBinary(const Digest & hash)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case NO_PASSWORD:
|
||||
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
case PLAINTEXT_PASSWORD:
|
||||
{
|
||||
password_hash = hash;
|
||||
return;
|
||||
}
|
||||
|
||||
case SHA256_PASSWORD:
|
||||
{
|
||||
if (hash.size() != 32)
|
||||
throw Exception(
|
||||
"Password hash for the 'SHA256_PASSWORD' authentication type has length " + std::to_string(hash.size())
|
||||
+ " but must be exactly 32 bytes.",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
password_hash = hash;
|
||||
return;
|
||||
}
|
||||
|
||||
case DOUBLE_SHA1_PASSWORD:
|
||||
{
|
||||
if (hash.size() != 20)
|
||||
throw Exception(
|
||||
"Password hash for the 'DOUBLE_SHA1_PASSWORD' authentication type has length " + std::to_string(hash.size())
|
||||
+ " but must be exactly 20 bytes.",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
password_hash = hash;
|
||||
return;
|
||||
}
|
||||
}
|
||||
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
|
||||
Digest Authentication::getPasswordDoubleSHA1() const
|
||||
Authentication::Digest Authentication::getPasswordDoubleSHA1() const
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
@ -198,12 +50,12 @@ bool Authentication::isCorrectPassword(const String & password_) const
|
||||
|
||||
case PLAINTEXT_PASSWORD:
|
||||
{
|
||||
if (password_ == StringRef{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()})
|
||||
if (password_ == std::string_view{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()})
|
||||
return true;
|
||||
|
||||
// For compatibility with MySQL clients which support only native authentication plugin, SHA1 can be passed instead of password.
|
||||
auto password_sha1 = encodeSHA1(password_hash);
|
||||
return password_ == StringRef{reinterpret_cast<const char *>(password_sha1.data()), password_sha1.size()};
|
||||
return password_ == std::string_view{reinterpret_cast<const char *>(password_sha1.data()), password_sha1.size()};
|
||||
}
|
||||
|
||||
case SHA256_PASSWORD:
|
||||
@ -234,10 +86,5 @@ void Authentication::checkPassword(const String & password_, const String & user
|
||||
throw Exception("Wrong password" + info_about_user_name(), ErrorCodes::WRONG_PASSWORD);
|
||||
}
|
||||
|
||||
|
||||
bool operator ==(const Authentication & lhs, const Authentication & rhs)
|
||||
{
|
||||
return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Types.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/OpenSSLHelpers.h>
|
||||
#include <Poco/SHA1Engine.h>
|
||||
#include <boost/algorithm/hex.hpp>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int SUPPORT_IS_DISABLED;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
|
||||
/// Authentication type and encrypted password for checking when an user logins.
|
||||
class Authentication
|
||||
{
|
||||
@ -27,7 +39,7 @@ public:
|
||||
|
||||
using Digest = std::vector<UInt8>;
|
||||
|
||||
Authentication(Authentication::Type type = NO_PASSWORD);
|
||||
Authentication(Authentication::Type type_ = NO_PASSWORD) : type(type_) {}
|
||||
Authentication(const Authentication & src) = default;
|
||||
Authentication & operator =(const Authentication & src) = default;
|
||||
Authentication(Authentication && src) = default;
|
||||
@ -36,17 +48,19 @@ public:
|
||||
Type getType() const { return type; }
|
||||
|
||||
/// Sets the password and encrypt it using the authentication type set in the constructor.
|
||||
void setPassword(const String & password);
|
||||
void setPassword(const String & password_);
|
||||
|
||||
/// Returns the password. Allowed to use only for Type::PLAINTEXT_PASSWORD.
|
||||
String getPassword() const;
|
||||
|
||||
/// Sets the password as a string of hexadecimal digits.
|
||||
void setPasswordHashHex(const String & hash);
|
||||
|
||||
String getPasswordHashHex() const;
|
||||
|
||||
/// Sets the password in binary form.
|
||||
void setPasswordHashBinary(const Digest & hash);
|
||||
|
||||
const Digest & getPasswordHashBinary() const { return password_hash; }
|
||||
|
||||
/// Returns SHA1(SHA1(password)) used by MySQL compatibility server for authentication.
|
||||
@ -60,11 +74,124 @@ public:
|
||||
/// `user_name` is only used for generating an error message if the password is incorrect.
|
||||
void checkPassword(const String & password, const String & user_name = String()) const;
|
||||
|
||||
friend bool operator ==(const Authentication & lhs, const Authentication & rhs);
|
||||
friend bool operator ==(const Authentication & lhs, const Authentication & rhs) { return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash); }
|
||||
friend bool operator !=(const Authentication & lhs, const Authentication & rhs) { return !(lhs == rhs); }
|
||||
|
||||
private:
|
||||
static Digest encodePlainText(const std::string_view & text) { return Digest(text.data(), text.data() + text.size()); }
|
||||
static Digest encodeSHA256(const std::string_view & text);
|
||||
static Digest encodeSHA1(const std::string_view & text);
|
||||
static Digest encodeSHA1(const Digest & text) { return encodeSHA1(std::string_view{reinterpret_cast<const char *>(text.data()), text.size()}); }
|
||||
static Digest encodeDoubleSHA1(const std::string_view & text) { return encodeSHA1(encodeSHA1(text)); }
|
||||
|
||||
Type type = Type::NO_PASSWORD;
|
||||
Digest password_hash;
|
||||
};
|
||||
|
||||
|
||||
inline Authentication::Digest Authentication::encodeSHA256(const std::string_view & text [[maybe_unused]])
|
||||
{
|
||||
#if USE_SSL
|
||||
Digest hash;
|
||||
hash.resize(32);
|
||||
::DB::encodeSHA256(text, hash.data());
|
||||
return hash;
|
||||
#else
|
||||
throw DB::Exception(
|
||||
"SHA256 passwords support is disabled, because ClickHouse was built without SSL library",
|
||||
DB::ErrorCodes::SUPPORT_IS_DISABLED);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline Authentication::Digest Authentication::encodeSHA1(const std::string_view & text)
|
||||
{
|
||||
Poco::SHA1Engine engine;
|
||||
engine.update(text.data(), text.size());
|
||||
return engine.digest();
|
||||
}
|
||||
|
||||
|
||||
inline void Authentication::setPassword(const String & password_)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case NO_PASSWORD:
|
||||
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
case PLAINTEXT_PASSWORD:
|
||||
return setPasswordHashBinary(encodePlainText(password_));
|
||||
|
||||
case SHA256_PASSWORD:
|
||||
return setPasswordHashBinary(encodeSHA256(password_));
|
||||
|
||||
case DOUBLE_SHA1_PASSWORD:
|
||||
return setPasswordHashBinary(encodeDoubleSHA1(password_));
|
||||
}
|
||||
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
|
||||
inline String Authentication::getPassword() const
|
||||
{
|
||||
if (type != PLAINTEXT_PASSWORD)
|
||||
throw Exception("Cannot decode the password", ErrorCodes::LOGICAL_ERROR);
|
||||
return String(password_hash.data(), password_hash.data() + password_hash.size());
|
||||
}
|
||||
|
||||
|
||||
inline void Authentication::setPasswordHashHex(const String & hash)
|
||||
{
|
||||
Digest digest;
|
||||
digest.resize(hash.size() / 2);
|
||||
boost::algorithm::unhex(hash.begin(), hash.end(), digest.data());
|
||||
setPasswordHashBinary(digest);
|
||||
}
|
||||
|
||||
inline String Authentication::getPasswordHashHex() const
|
||||
{
|
||||
String hex;
|
||||
hex.resize(password_hash.size() * 2);
|
||||
boost::algorithm::hex(password_hash.begin(), password_hash.end(), hex.data());
|
||||
return hex;
|
||||
}
|
||||
|
||||
|
||||
inline void Authentication::setPasswordHashBinary(const Digest & hash)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case NO_PASSWORD:
|
||||
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
case PLAINTEXT_PASSWORD:
|
||||
{
|
||||
password_hash = hash;
|
||||
return;
|
||||
}
|
||||
|
||||
case SHA256_PASSWORD:
|
||||
{
|
||||
if (hash.size() != 32)
|
||||
throw Exception(
|
||||
"Password hash for the 'SHA256_PASSWORD' authentication type has length " + std::to_string(hash.size())
|
||||
+ " but must be exactly 32 bytes.",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
password_hash = hash;
|
||||
return;
|
||||
}
|
||||
|
||||
case DOUBLE_SHA1_PASSWORD:
|
||||
{
|
||||
if (hash.size() != 20)
|
||||
throw Exception(
|
||||
"Password hash for the 'DOUBLE_SHA1_PASSWORD' authentication type has length " + std::to_string(hash.size())
|
||||
+ " but must be exactly 20 bytes.",
|
||||
ErrorCodes::BAD_ARGUMENTS);
|
||||
password_hash = hash;
|
||||
return;
|
||||
}
|
||||
}
|
||||
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -10,7 +10,8 @@ bool User::equal(const IAccessEntity & other) const
|
||||
return false;
|
||||
const auto & other_user = typeid_cast<const User &>(other);
|
||||
return (authentication == other_user.authentication) && (allowed_client_hosts == other_user.allowed_client_hosts)
|
||||
&& (access == other_user.access) && (profile == other_user.profile);
|
||||
&& (access == other_user.access) && (access_with_grant_option == other_user.access_with_grant_option)
|
||||
&& (profile == other_user.profile);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -14,8 +14,9 @@ namespace DB
|
||||
struct User : public IAccessEntity
|
||||
{
|
||||
Authentication authentication;
|
||||
AllowedClientHosts allowed_client_hosts;
|
||||
AllowedClientHosts allowed_client_hosts{AllowedClientHosts::AnyHostTag{}};
|
||||
AccessRights access;
|
||||
AccessRights access_with_grant_option;
|
||||
String profile;
|
||||
|
||||
bool equal(const IAccessEntity & other) const override;
|
||||
|
@ -90,15 +90,16 @@ namespace
|
||||
{
|
||||
Poco::Util::AbstractConfiguration::Keys keys;
|
||||
config.keys(networks_config, keys);
|
||||
user->allowed_client_hosts.clear();
|
||||
for (const String & key : keys)
|
||||
{
|
||||
String value = config.getString(networks_config + "." + key);
|
||||
if (key.starts_with("ip"))
|
||||
user->allowed_client_hosts.addSubnet(value);
|
||||
else if (key.starts_with("host_regexp"))
|
||||
user->allowed_client_hosts.addHostRegexp(value);
|
||||
user->allowed_client_hosts.addNameRegexp(value);
|
||||
else if (key.starts_with("host"))
|
||||
user->allowed_client_hosts.addHostName(value);
|
||||
user->allowed_client_hosts.addName(value);
|
||||
else
|
||||
throw Exception("Unknown address pattern type: " + key, ErrorCodes::UNKNOWN_ADDRESS_PATTERN_TYPE);
|
||||
}
|
||||
@ -143,7 +144,6 @@ namespace
|
||||
user->access.fullRevoke(AccessFlags::databaseLevel());
|
||||
for (const String & database : *databases)
|
||||
user->access.grant(AccessFlags::databaseLevel(), database);
|
||||
user->access.grant(AccessFlags::databaseLevel(), "system"); /// Anyone has access to the "system" database.
|
||||
}
|
||||
|
||||
if (dictionaries)
|
||||
@ -155,6 +155,8 @@ namespace
|
||||
else if (databases)
|
||||
user->access.grant(AccessType::dictGet, IDictionary::NO_DATABASE_TAG);
|
||||
|
||||
user->access_with_grant_option = user->access;
|
||||
|
||||
return user;
|
||||
}
|
||||
|
||||
|
@ -74,7 +74,8 @@ void Connection::connect(const ConnectionTimeouts & timeouts)
|
||||
|
||||
current_resolved_address = DNSResolver::instance().resolveAddress(host, port);
|
||||
|
||||
socket->connect(*current_resolved_address, timeouts.connection_timeout);
|
||||
const auto & connection_timeout = static_cast<bool>(secure) ? timeouts.secure_connection_timeout : timeouts.connection_timeout;
|
||||
socket->connect(*current_resolved_address, connection_timeout);
|
||||
socket->setReceiveTimeout(timeouts.receive_timeout);
|
||||
socket->setSendTimeout(timeouts.send_timeout);
|
||||
socket->setNoDelay(true);
|
||||
|
@ -15,7 +15,8 @@
|
||||
M(DiskSpaceReservedForMerge, "Disk space reserved for currently running background merges. It is slightly more than the total size of currently merging parts.") \
|
||||
M(DistributedSend, "Number of connections to remote servers sending data that was INSERTed into Distributed tables. Both synchronous and asynchronous mode.") \
|
||||
M(QueryPreempted, "Number of queries that are stopped and waiting due to 'priority' setting.") \
|
||||
M(TCPConnection, "Number of connections to TCP server (clients with native interface)") \
|
||||
M(TCPConnection, "Number of connections to TCP server (clients with native interface), also included server-server distributed query connections") \
|
||||
M(MySQLConnection, "Number of client connections using MySQL protocol") \
|
||||
M(HTTPConnection, "Number of connections to HTTP server") \
|
||||
M(InterserverConnection, "Number of connections from other replicas to fetch parts") \
|
||||
M(OpenFileForRead, "Number of files open for reading") \
|
||||
|
@ -481,6 +481,7 @@ namespace ErrorCodes
|
||||
extern const int UNABLE_TO_SKIP_UNUSED_SHARDS = 507;
|
||||
extern const int UNKNOWN_ACCESS_TYPE = 508;
|
||||
extern const int INVALID_GRANT = 509;
|
||||
extern const int CACHE_DICTIONARY_UPDATE_FAIL = 510;
|
||||
|
||||
extern const int KEEPER_EXCEPTION = 999;
|
||||
extern const int POCO_EXCEPTION = 1000;
|
||||
|
@ -1,12 +1,17 @@
|
||||
#include <cstdlib>
|
||||
|
||||
#include "MemoryTracker.h"
|
||||
#include <common/likely.h>
|
||||
#include <common/logger_useful.h>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include "Common/TraceCollector.h"
|
||||
#include <Common/CurrentThread.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/formatReadable.h>
|
||||
#include <Common/CurrentThread.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <common/likely.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <ext/singleton.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -73,7 +78,7 @@ void MemoryTracker::alloc(Int64 size)
|
||||
return;
|
||||
|
||||
/** Using memory_order_relaxed means that if allocations are done simultaneously,
|
||||
* we allow exception about memory limit exceeded to be thrown only on next allocation.
|
||||
* we allow exception about memory limit exceeded to be thrown only on next allocation.
|
||||
* So, we allow over-allocations.
|
||||
*/
|
||||
Int64 will_be = size + amount.fetch_add(size, std::memory_order_relaxed);
|
||||
@ -81,7 +86,8 @@ void MemoryTracker::alloc(Int64 size)
|
||||
if (metric != CurrentMetrics::end())
|
||||
CurrentMetrics::add(metric, size);
|
||||
|
||||
Int64 current_limit = limit.load(std::memory_order_relaxed);
|
||||
Int64 current_hard_limit = hard_limit.load(std::memory_order_relaxed);
|
||||
Int64 current_profiler_limit = profiler_limit.load(std::memory_order_relaxed);
|
||||
|
||||
/// Using non-thread-safe random number generator. Joint distribution in different threads would not be uniform.
|
||||
/// In this case, it doesn't matter.
|
||||
@ -98,12 +104,19 @@ void MemoryTracker::alloc(Int64 size)
|
||||
message << " " << description;
|
||||
message << ": fault injected. Would use " << formatReadableSizeWithBinarySuffix(will_be)
|
||||
<< " (attempt to allocate chunk of " << size << " bytes)"
|
||||
<< ", maximum: " << formatReadableSizeWithBinarySuffix(current_limit);
|
||||
<< ", maximum: " << formatReadableSizeWithBinarySuffix(current_hard_limit);
|
||||
|
||||
throw DB::Exception(message.str(), DB::ErrorCodes::MEMORY_LIMIT_EXCEEDED);
|
||||
}
|
||||
|
||||
if (unlikely(current_limit && will_be > current_limit))
|
||||
if (unlikely(current_profiler_limit && will_be > current_profiler_limit))
|
||||
{
|
||||
auto no_track = blocker.cancel();
|
||||
ext::Singleton<DB::TraceCollector>()->collect(size);
|
||||
setOrRaiseProfilerLimit(current_profiler_limit + Int64(std::ceil((will_be - current_profiler_limit) / profiler_step)) * profiler_step);
|
||||
}
|
||||
|
||||
if (unlikely(current_hard_limit && will_be > current_hard_limit))
|
||||
{
|
||||
free(size);
|
||||
|
||||
@ -116,7 +129,7 @@ void MemoryTracker::alloc(Int64 size)
|
||||
message << " " << description;
|
||||
message << " exceeded: would use " << formatReadableSizeWithBinarySuffix(will_be)
|
||||
<< " (attempt to allocate chunk of " << size << " bytes)"
|
||||
<< ", maximum: " << formatReadableSizeWithBinarySuffix(current_limit);
|
||||
<< ", maximum: " << formatReadableSizeWithBinarySuffix(current_hard_limit);
|
||||
|
||||
throw DB::Exception(message.str(), DB::ErrorCodes::MEMORY_LIMIT_EXCEEDED);
|
||||
}
|
||||
@ -174,7 +187,8 @@ void MemoryTracker::resetCounters()
|
||||
{
|
||||
amount.store(0, std::memory_order_relaxed);
|
||||
peak.store(0, std::memory_order_relaxed);
|
||||
limit.store(0, std::memory_order_relaxed);
|
||||
hard_limit.store(0, std::memory_order_relaxed);
|
||||
profiler_limit.store(0, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
|
||||
@ -187,11 +201,20 @@ void MemoryTracker::reset()
|
||||
}
|
||||
|
||||
|
||||
void MemoryTracker::setOrRaiseLimit(Int64 value)
|
||||
void MemoryTracker::setOrRaiseHardLimit(Int64 value)
|
||||
{
|
||||
/// This is just atomic set to maximum.
|
||||
Int64 old_value = limit.load(std::memory_order_relaxed);
|
||||
while (old_value < value && !limit.compare_exchange_weak(old_value, value))
|
||||
Int64 old_value = hard_limit.load(std::memory_order_relaxed);
|
||||
while (old_value < value && !hard_limit.compare_exchange_weak(old_value, value))
|
||||
;
|
||||
}
|
||||
|
||||
|
||||
void MemoryTracker::setOrRaiseProfilerLimit(Int64 value)
|
||||
{
|
||||
/// This is just atomic set to maximum.
|
||||
Int64 old_value = profiler_limit.load(std::memory_order_relaxed);
|
||||
while (old_value < value && !profiler_limit.compare_exchange_weak(old_value, value))
|
||||
;
|
||||
}
|
||||
|
||||
@ -207,7 +230,7 @@ namespace CurrentMemoryTracker
|
||||
if (untracked > untracked_memory_limit)
|
||||
{
|
||||
/// Zero untracked before track. If tracker throws out-of-limit we would be able to alloc up to untracked_memory_limit bytes
|
||||
/// more. It could be usefull for enlarge Exception message in rethrow logic.
|
||||
/// more. It could be useful to enlarge Exception message in rethrow logic.
|
||||
Int64 tmp = untracked;
|
||||
untracked = 0;
|
||||
memory_tracker->alloc(tmp);
|
||||
@ -218,10 +241,7 @@ namespace CurrentMemoryTracker
|
||||
void realloc(Int64 old_size, Int64 new_size)
|
||||
{
|
||||
Int64 addition = new_size - old_size;
|
||||
if (addition > 0)
|
||||
alloc(addition);
|
||||
else
|
||||
free(-addition);
|
||||
addition > 0 ? alloc(addition) : free(-addition);
|
||||
}
|
||||
|
||||
void free(Int64 size)
|
||||
|
@ -15,7 +15,10 @@ class MemoryTracker
|
||||
{
|
||||
std::atomic<Int64> amount {0};
|
||||
std::atomic<Int64> peak {0};
|
||||
std::atomic<Int64> limit {0};
|
||||
std::atomic<Int64> hard_limit {0};
|
||||
std::atomic<Int64> profiler_limit {0};
|
||||
|
||||
Int64 profiler_step = 0;
|
||||
|
||||
/// To test exception safety of calling code, memory tracker throws an exception on each memory allocation with specified probability.
|
||||
double fault_probability = 0;
|
||||
@ -32,7 +35,6 @@ class MemoryTracker
|
||||
|
||||
public:
|
||||
MemoryTracker(VariableContext level_ = VariableContext::Thread) : level(level_) {}
|
||||
MemoryTracker(Int64 limit_, VariableContext level_ = VariableContext::Thread) : limit(limit_), level(level_) {}
|
||||
MemoryTracker(MemoryTracker * parent_, VariableContext level_ = VariableContext::Thread) : parent(parent_), level(level_) {}
|
||||
|
||||
~MemoryTracker();
|
||||
@ -66,21 +68,22 @@ public:
|
||||
return peak.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
void setLimit(Int64 limit_)
|
||||
{
|
||||
limit.store(limit_, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
/** Set limit if it was not set.
|
||||
* Otherwise, set limit to new value, if new value is greater than previous limit.
|
||||
*/
|
||||
void setOrRaiseLimit(Int64 value);
|
||||
void setOrRaiseHardLimit(Int64 value);
|
||||
void setOrRaiseProfilerLimit(Int64 value);
|
||||
|
||||
void setFaultProbability(double value)
|
||||
{
|
||||
fault_probability = value;
|
||||
}
|
||||
|
||||
void setProfilerStep(Int64 value)
|
||||
{
|
||||
profiler_step = value;
|
||||
}
|
||||
|
||||
/// next should be changed only once: from nullptr to some value.
|
||||
/// NOTE: It is not true in MergeListElement
|
||||
void setParent(MemoryTracker * elem)
|
||||
|
@ -3,11 +3,20 @@
|
||||
#include "OpenSSLHelpers.h"
|
||||
#include <ext/scope_guard.h>
|
||||
#include <openssl/err.h>
|
||||
#include <openssl/sha.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
#pragma GCC diagnostic warning "-Wold-style-cast"
|
||||
|
||||
void encodeSHA256(const std::string_view & text, unsigned char * out)
|
||||
{
|
||||
SHA256_CTX ctx;
|
||||
SHA256_Init(&ctx);
|
||||
SHA256_Update(&ctx, reinterpret_cast<const UInt8 *>(text.data()), text.size());
|
||||
SHA256_Final(out, &ctx);
|
||||
}
|
||||
|
||||
String getOpenSSLErrors()
|
||||
{
|
||||
BIO * mem = BIO_new(BIO_s_mem());
|
||||
|
@ -7,6 +7,8 @@
|
||||
|
||||
namespace DB
|
||||
{
|
||||
/// Encodes `text` and puts the result to `out` which must be at least 32 bytes long.
|
||||
void encodeSHA256(const std::string_view & text, unsigned char * out);
|
||||
|
||||
/// Returns concatenation of error strings for all errors that OpenSSL has recorded, emptying the error queue.
|
||||
String getOpenSSLErrors();
|
||||
|
@ -1,92 +1,38 @@
|
||||
#include "QueryProfiler.h"
|
||||
|
||||
#include <random>
|
||||
#include <common/phdr_cache.h>
|
||||
#include <common/config_common.h>
|
||||
#include <common/StringRef.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <Common/PipeFDs.h>
|
||||
#include <Common/StackTrace.h>
|
||||
#include <Common/CurrentThread.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/thread_local_rng.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/WriteBufferFromFileDescriptorDiscardOnFailure.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/StackTrace.h>
|
||||
#include <Common/TraceCollector.h>
|
||||
#include <Common/thread_local_rng.h>
|
||||
#include <common/StringRef.h>
|
||||
#include <common/config_common.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <common/phdr_cache.h>
|
||||
#include <ext/singleton.h>
|
||||
|
||||
#include <random>
|
||||
|
||||
namespace ProfileEvents
|
||||
{
|
||||
extern const Event QueryProfilerSignalOverruns;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
extern LazyPipeFDs trace_pipe;
|
||||
|
||||
namespace
|
||||
{
|
||||
/// Normally query_id is a UUID (string with a fixed length) but user can provide custom query_id.
|
||||
/// Thus upper bound on query_id length should be introduced to avoid buffer overflow in signal handler.
|
||||
constexpr size_t QUERY_ID_MAX_LEN = 1024;
|
||||
|
||||
#if defined(OS_LINUX)
|
||||
thread_local size_t write_trace_iteration = 0;
|
||||
#endif
|
||||
|
||||
void writeTraceInfo(TimerType timer_type, int /* sig */, siginfo_t * info, void * context)
|
||||
void writeTraceInfo(TraceType trace_type, int /* sig */, siginfo_t * info, void * context)
|
||||
{
|
||||
int overrun_count = 0;
|
||||
#if defined(OS_LINUX)
|
||||
/// Quickly drop if signal handler is called too frequently.
|
||||
/// Otherwise we may end up infinitelly processing signals instead of doing any useful work.
|
||||
++write_trace_iteration;
|
||||
if (info && info->si_overrun > 0)
|
||||
{
|
||||
/// But pass with some frequency to avoid drop of all traces.
|
||||
if (write_trace_iteration % info->si_overrun == 0)
|
||||
{
|
||||
ProfileEvents::increment(ProfileEvents::QueryProfilerSignalOverruns, info->si_overrun);
|
||||
}
|
||||
else
|
||||
{
|
||||
ProfileEvents::increment(ProfileEvents::QueryProfilerSignalOverruns, info->si_overrun + 1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (info)
|
||||
overrun_count = info->si_overrun;
|
||||
#else
|
||||
UNUSED(info);
|
||||
#endif
|
||||
|
||||
constexpr size_t buf_size = sizeof(char) + // TraceCollector stop flag
|
||||
8 * sizeof(char) + // maximum VarUInt length for string size
|
||||
QUERY_ID_MAX_LEN * sizeof(char) + // maximum query_id length
|
||||
sizeof(UInt8) + // number of stack frames
|
||||
sizeof(StackTrace::Frames) + // collected stack trace, maximum capacity
|
||||
sizeof(TimerType) + // timer type
|
||||
sizeof(UInt64); // thread_id
|
||||
char buffer[buf_size];
|
||||
WriteBufferFromFileDescriptorDiscardOnFailure out(trace_pipe.fds_rw[1], buf_size, buffer);
|
||||
|
||||
StringRef query_id = CurrentThread::getQueryId();
|
||||
query_id.size = std::min(query_id.size, QUERY_ID_MAX_LEN);
|
||||
|
||||
UInt64 thread_id = CurrentThread::get().thread_id;
|
||||
|
||||
const auto signal_context = *reinterpret_cast<ucontext_t *>(context);
|
||||
const StackTrace stack_trace(signal_context);
|
||||
|
||||
writeChar(false, out);
|
||||
writeStringBinary(query_id, out);
|
||||
|
||||
size_t stack_trace_size = stack_trace.getSize();
|
||||
size_t stack_trace_offset = stack_trace.getOffset();
|
||||
writeIntBinary(UInt8(stack_trace_size - stack_trace_offset), out);
|
||||
for (size_t i = stack_trace_offset; i < stack_trace_size; ++i)
|
||||
writePODBinary(stack_trace.getFrames()[i], out);
|
||||
|
||||
writePODBinary(timer_type, out);
|
||||
writePODBinary(thread_id, out);
|
||||
out.next();
|
||||
ext::Singleton<TraceCollector>()->collect(trace_type, stack_trace, overrun_count);
|
||||
}
|
||||
|
||||
[[maybe_unused]] const UInt32 TIMER_PRECISION = 1e9;
|
||||
@ -135,11 +81,11 @@ QueryProfilerBase<ProfilerImpl>::QueryProfilerBase(const UInt64 thread_id, const
|
||||
sev.sigev_notify = SIGEV_THREAD_ID;
|
||||
sev.sigev_signo = pause_signal;
|
||||
|
||||
#if defined(__FreeBSD__)
|
||||
# if defined(__FreeBSD__)
|
||||
sev._sigev_un._threadid = thread_id;
|
||||
#else
|
||||
# else
|
||||
sev._sigev_un._tid = thread_id;
|
||||
#endif
|
||||
# endif
|
||||
if (timer_create(clock_type, &sev, &timer_id))
|
||||
{
|
||||
/// In Google Cloud Run, the function "timer_create" is implemented incorrectly as of 2020-01-25.
|
||||
@ -206,7 +152,7 @@ QueryProfilerReal::QueryProfilerReal(const UInt64 thread_id, const UInt32 period
|
||||
|
||||
void QueryProfilerReal::signalHandler(int sig, siginfo_t * info, void * context)
|
||||
{
|
||||
writeTraceInfo(TimerType::Real, sig, info, context);
|
||||
writeTraceInfo(TraceType::REAL_TIME, sig, info, context);
|
||||
}
|
||||
|
||||
QueryProfilerCpu::QueryProfilerCpu(const UInt64 thread_id, const UInt32 period)
|
||||
@ -215,7 +161,7 @@ QueryProfilerCpu::QueryProfilerCpu(const UInt64 thread_id, const UInt32 period)
|
||||
|
||||
void QueryProfilerCpu::signalHandler(int sig, siginfo_t * info, void * context)
|
||||
{
|
||||
writeTraceInfo(TimerType::Cpu, sig, info, context);
|
||||
writeTraceInfo(TraceType::CPU_TIME, sig, info, context);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -15,12 +15,6 @@ namespace Poco
|
||||
namespace DB
|
||||
{
|
||||
|
||||
enum class TimerType : UInt8
|
||||
{
|
||||
Real,
|
||||
Cpu,
|
||||
};
|
||||
|
||||
/**
|
||||
* Query profiler implementation for selected thread.
|
||||
*
|
||||
|
@ -1,25 +1,38 @@
|
||||
#include "TraceCollector.h"
|
||||
|
||||
#include <Core/Field.h>
|
||||
#include <IO/ReadBufferFromFileDescriptor.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteBufferFromFileDescriptor.h>
|
||||
#include <IO/WriteBufferFromFileDescriptorDiscardOnFailure.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Interpreters/TraceLog.h>
|
||||
#include <Poco/Logger.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/PipeFDs.h>
|
||||
#include <Common/StackTrace.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/ReadBufferFromFileDescriptor.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/WriteBufferFromFileDescriptor.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Interpreters/TraceLog.h>
|
||||
|
||||
#include <unistd.h>
|
||||
#include <fcntl.h>
|
||||
|
||||
|
||||
namespace ProfileEvents
|
||||
{
|
||||
extern const Event QueryProfilerSignalOverruns;
|
||||
}
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
LazyPipeFDs trace_pipe;
|
||||
namespace
|
||||
{
|
||||
/// Normally query_id is a UUID (string with a fixed length) but user can provide custom query_id.
|
||||
/// Thus upper bound on query_id length should be introduced to avoid buffer overflow in signal handler.
|
||||
constexpr size_t QUERY_ID_MAX_LEN = 1024;
|
||||
|
||||
thread_local size_t write_trace_iteration = 0;
|
||||
}
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
@ -27,20 +40,15 @@ namespace ErrorCodes
|
||||
extern const int THREAD_IS_NOT_JOINABLE;
|
||||
}
|
||||
|
||||
TraceCollector::TraceCollector(std::shared_ptr<TraceLog> & trace_log_)
|
||||
: log(&Poco::Logger::get("TraceCollector"))
|
||||
, trace_log(trace_log_)
|
||||
TraceCollector::TraceCollector()
|
||||
{
|
||||
if (trace_log == nullptr)
|
||||
throw Exception("Invalid trace log pointer passed", ErrorCodes::NULL_POINTER_DEREFERENCE);
|
||||
|
||||
trace_pipe.open();
|
||||
pipe.open();
|
||||
|
||||
/** Turn write end of pipe to non-blocking mode to avoid deadlocks
|
||||
* when QueryProfiler is invoked under locks and TraceCollector cannot pull data from pipe.
|
||||
*/
|
||||
trace_pipe.setNonBlocking();
|
||||
trace_pipe.tryIncreaseSize(1 << 20);
|
||||
pipe.setNonBlocking();
|
||||
pipe.tryIncreaseSize(1 << 20);
|
||||
|
||||
thread = ThreadFromGlobalPool(&TraceCollector::run, this);
|
||||
}
|
||||
@ -48,14 +56,101 @@ TraceCollector::TraceCollector(std::shared_ptr<TraceLog> & trace_log_)
|
||||
TraceCollector::~TraceCollector()
|
||||
{
|
||||
if (!thread.joinable())
|
||||
LOG_ERROR(log, "TraceCollector thread is malformed and cannot be joined");
|
||||
LOG_ERROR(&Poco::Logger::get("TraceCollector"), "TraceCollector thread is malformed and cannot be joined");
|
||||
else
|
||||
{
|
||||
TraceCollector::notifyToStop();
|
||||
stop();
|
||||
thread.join();
|
||||
}
|
||||
|
||||
trace_pipe.close();
|
||||
pipe.close();
|
||||
}
|
||||
|
||||
void TraceCollector::collect(TraceType trace_type, const StackTrace & stack_trace, int overrun_count)
|
||||
{
|
||||
/// Quickly drop if signal handler is called too frequently.
|
||||
/// Otherwise we may end up infinitelly processing signals instead of doing any useful work.
|
||||
++write_trace_iteration;
|
||||
if (overrun_count)
|
||||
{
|
||||
/// But pass with some frequency to avoid drop of all traces.
|
||||
if (write_trace_iteration % overrun_count == 0)
|
||||
{
|
||||
ProfileEvents::increment(ProfileEvents::QueryProfilerSignalOverruns, overrun_count);
|
||||
}
|
||||
else
|
||||
{
|
||||
ProfileEvents::increment(ProfileEvents::QueryProfilerSignalOverruns, overrun_count + 1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr size_t buf_size = sizeof(char) + // TraceCollector stop flag
|
||||
8 * sizeof(char) + // maximum VarUInt length for string size
|
||||
QUERY_ID_MAX_LEN * sizeof(char) + // maximum query_id length
|
||||
sizeof(UInt8) + // number of stack frames
|
||||
sizeof(StackTrace::Frames) + // collected stack trace, maximum capacity
|
||||
sizeof(TraceType) + // trace type
|
||||
sizeof(UInt64) + // thread_id
|
||||
sizeof(UInt64); // size
|
||||
char buffer[buf_size];
|
||||
WriteBufferFromFileDescriptorDiscardOnFailure out(pipe.fds_rw[1], buf_size, buffer);
|
||||
|
||||
StringRef query_id = CurrentThread::getQueryId();
|
||||
query_id.size = std::min(query_id.size, QUERY_ID_MAX_LEN);
|
||||
|
||||
auto thread_id = CurrentThread::get().thread_id;
|
||||
|
||||
writeChar(false, out);
|
||||
writeStringBinary(query_id, out);
|
||||
|
||||
size_t stack_trace_size = stack_trace.getSize();
|
||||
size_t stack_trace_offset = stack_trace.getOffset();
|
||||
writeIntBinary(UInt8(stack_trace_size - stack_trace_offset), out);
|
||||
for (size_t i = stack_trace_offset; i < stack_trace_size; ++i)
|
||||
writePODBinary(stack_trace.getFrames()[i], out);
|
||||
|
||||
writePODBinary(trace_type, out);
|
||||
writePODBinary(thread_id, out);
|
||||
writePODBinary(UInt64(0), out);
|
||||
|
||||
out.next();
|
||||
}
|
||||
|
||||
void TraceCollector::collect(UInt64 size)
|
||||
{
|
||||
constexpr size_t buf_size = sizeof(char) + // TraceCollector stop flag
|
||||
8 * sizeof(char) + // maximum VarUInt length for string size
|
||||
QUERY_ID_MAX_LEN * sizeof(char) + // maximum query_id length
|
||||
sizeof(UInt8) + // number of stack frames
|
||||
sizeof(StackTrace::Frames) + // collected stack trace, maximum capacity
|
||||
sizeof(TraceType) + // trace type
|
||||
sizeof(UInt64) + // thread_id
|
||||
sizeof(UInt64); // size
|
||||
char buffer[buf_size];
|
||||
WriteBufferFromFileDescriptorDiscardOnFailure out(pipe.fds_rw[1], buf_size, buffer);
|
||||
|
||||
StringRef query_id = CurrentThread::getQueryId();
|
||||
query_id.size = std::min(query_id.size, QUERY_ID_MAX_LEN);
|
||||
|
||||
auto thread_id = CurrentThread::get().thread_id;
|
||||
|
||||
writeChar(false, out);
|
||||
writeStringBinary(query_id, out);
|
||||
|
||||
const auto & stack_trace = StackTrace();
|
||||
|
||||
size_t stack_trace_size = stack_trace.getSize();
|
||||
size_t stack_trace_offset = stack_trace.getOffset();
|
||||
writeIntBinary(UInt8(stack_trace_size - stack_trace_offset), out);
|
||||
for (size_t i = stack_trace_offset; i < stack_trace_size; ++i)
|
||||
writePODBinary(stack_trace.getFrames()[i], out);
|
||||
|
||||
writePODBinary(TraceType::MEMORY, out);
|
||||
writePODBinary(thread_id, out);
|
||||
writePODBinary(size, out);
|
||||
|
||||
out.next();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -68,16 +163,16 @@ TraceCollector::~TraceCollector()
|
||||
* NOTE: TraceCollector will NOT stop immediately as there may be some data left in the pipe
|
||||
* before stop message.
|
||||
*/
|
||||
void TraceCollector::notifyToStop()
|
||||
void TraceCollector::stop()
|
||||
{
|
||||
WriteBufferFromFileDescriptor out(trace_pipe.fds_rw[1]);
|
||||
WriteBufferFromFileDescriptor out(pipe.fds_rw[1]);
|
||||
writeChar(true, out);
|
||||
out.next();
|
||||
}
|
||||
|
||||
void TraceCollector::run()
|
||||
{
|
||||
ReadBufferFromFileDescriptor in(trace_pipe.fds_rw[0]);
|
||||
ReadBufferFromFileDescriptor in(pipe.fds_rw[0]);
|
||||
|
||||
while (true)
|
||||
{
|
||||
@ -89,27 +184,33 @@ void TraceCollector::run()
|
||||
std::string query_id;
|
||||
readStringBinary(query_id, in);
|
||||
|
||||
UInt8 size = 0;
|
||||
readIntBinary(size, in);
|
||||
UInt8 trace_size = 0;
|
||||
readIntBinary(trace_size, in);
|
||||
|
||||
Array trace;
|
||||
trace.reserve(size);
|
||||
trace.reserve(trace_size);
|
||||
|
||||
for (size_t i = 0; i < size; i++)
|
||||
for (size_t i = 0; i < trace_size; i++)
|
||||
{
|
||||
uintptr_t addr = 0;
|
||||
readPODBinary(addr, in);
|
||||
trace.emplace_back(UInt64(addr));
|
||||
}
|
||||
|
||||
TimerType timer_type;
|
||||
readPODBinary(timer_type, in);
|
||||
TraceType trace_type;
|
||||
readPODBinary(trace_type, in);
|
||||
|
||||
UInt64 thread_id;
|
||||
readPODBinary(thread_id, in);
|
||||
|
||||
TraceLogElement element{std::time(nullptr), timer_type, thread_id, query_id, trace};
|
||||
trace_log->add(element);
|
||||
UInt64 size;
|
||||
readPODBinary(size, in);
|
||||
|
||||
if (trace_log)
|
||||
{
|
||||
TraceLogElement element{std::time(nullptr), trace_type, thread_id, query_id, trace, size};
|
||||
trace_log->add(element);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "Common/PipeFDs.h"
|
||||
#include <Common/ThreadPool.h>
|
||||
|
||||
class StackTrace;
|
||||
|
||||
namespace Poco
|
||||
{
|
||||
class Logger;
|
||||
@ -12,21 +15,31 @@ namespace DB
|
||||
|
||||
class TraceLog;
|
||||
|
||||
enum class TraceType : UInt8
|
||||
{
|
||||
REAL_TIME,
|
||||
CPU_TIME,
|
||||
MEMORY,
|
||||
};
|
||||
|
||||
class TraceCollector
|
||||
{
|
||||
public:
|
||||
TraceCollector();
|
||||
~TraceCollector();
|
||||
|
||||
void setTraceLog(const std::shared_ptr<TraceLog> & trace_log_) { trace_log = trace_log_; }
|
||||
|
||||
void collect(TraceType type, const StackTrace & stack_trace, int overrun_count = 0);
|
||||
void collect(UInt64 size);
|
||||
|
||||
private:
|
||||
Poco::Logger * log;
|
||||
std::shared_ptr<TraceLog> trace_log;
|
||||
ThreadFromGlobalPool thread;
|
||||
LazyPipeFDs pipe;
|
||||
|
||||
void run();
|
||||
|
||||
static void notifyToStop();
|
||||
|
||||
public:
|
||||
TraceCollector(std::shared_ptr<TraceLog> & trace_log_);
|
||||
|
||||
~TraceCollector();
|
||||
void stop();
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -1,14 +1,16 @@
|
||||
#if defined(OS_LINUX)
|
||||
#include <malloc.h>
|
||||
#elif defined(OS_DARWIN)
|
||||
#include <malloc/malloc.h>
|
||||
#endif
|
||||
#include <new>
|
||||
|
||||
#include <common/config_common.h>
|
||||
#include <common/memory.h>
|
||||
#include <Common/MemoryTracker.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <new>
|
||||
|
||||
#if defined(OS_LINUX)
|
||||
# include <malloc.h>
|
||||
#elif defined(OS_DARWIN)
|
||||
# include <malloc/malloc.h>
|
||||
#endif
|
||||
|
||||
/// Replace default new/delete with memory tracking versions.
|
||||
/// @sa https://en.cppreference.com/w/cpp/memory/new/operator_new
|
||||
/// https://en.cppreference.com/w/cpp/memory/new/operator_delete
|
||||
@ -29,7 +31,7 @@ ALWAYS_INLINE void trackMemory(std::size_t size)
|
||||
#endif
|
||||
}
|
||||
|
||||
ALWAYS_INLINE bool trackMemoryNoExept(std::size_t size) noexcept
|
||||
ALWAYS_INLINE bool trackMemoryNoExcept(std::size_t size) noexcept
|
||||
{
|
||||
try
|
||||
{
|
||||
@ -54,11 +56,11 @@ ALWAYS_INLINE void untrackMemory(void * ptr [[maybe_unused]], std::size_t size [
|
||||
#else
|
||||
if (size)
|
||||
CurrentMemoryTracker::free(size);
|
||||
#ifdef _GNU_SOURCE
|
||||
# ifdef _GNU_SOURCE
|
||||
/// It's innaccurate resource free for sanitizers. malloc_usable_size() result is greater or equal to allocated size.
|
||||
else
|
||||
CurrentMemoryTracker::free(malloc_usable_size(ptr));
|
||||
#endif
|
||||
# endif
|
||||
#endif
|
||||
}
|
||||
catch (...)
|
||||
@ -83,14 +85,14 @@ void * operator new[](std::size_t size)
|
||||
|
||||
void * operator new(std::size_t size, const std::nothrow_t &) noexcept
|
||||
{
|
||||
if (likely(Memory::trackMemoryNoExept(size)))
|
||||
if (likely(Memory::trackMemoryNoExcept(size)))
|
||||
return Memory::newNoExept(size);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void * operator new[](std::size_t size, const std::nothrow_t &) noexcept
|
||||
{
|
||||
if (likely(Memory::trackMemoryNoExept(size)))
|
||||
if (likely(Memory::trackMemoryNoExcept(size)))
|
||||
return Memory::newNoExept(size);
|
||||
return nullptr;
|
||||
}
|
||||
|
17
dbms/src/Common/tests/gtest_global_context.h
Normal file
17
dbms/src/Common/tests/gtest_global_context.h
Normal file
@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
#include <Interpreters/Context.h>
|
||||
|
||||
inline DB::Context createContext()
|
||||
{
|
||||
auto context = DB::Context::createGlobal();
|
||||
context.makeGlobalContext();
|
||||
context.setPath("./");
|
||||
return context;
|
||||
}
|
||||
|
||||
inline const DB::Context & getContext()
|
||||
{
|
||||
static DB::Context global_context = createContext();
|
||||
return global_context;
|
||||
}
|
@ -241,30 +241,35 @@ void decompressDataForType(const char * source, UInt32 source_size, char * dest)
|
||||
|
||||
const char * source_end = source + source_size;
|
||||
|
||||
if (source + sizeof(UInt32) > source_end)
|
||||
return;
|
||||
|
||||
const UInt32 items_count = unalignedLoad<UInt32>(source);
|
||||
source += sizeof(items_count);
|
||||
|
||||
ValueType prev_value{};
|
||||
UnsignedDeltaType prev_delta{};
|
||||
|
||||
if (source < source_end)
|
||||
{
|
||||
prev_value = unalignedLoad<ValueType>(source);
|
||||
unalignedStore<ValueType>(dest, prev_value);
|
||||
// decoding first item
|
||||
if (source + sizeof(ValueType) > source_end || items_count < 1)
|
||||
return;
|
||||
|
||||
source += sizeof(prev_value);
|
||||
dest += sizeof(prev_value);
|
||||
}
|
||||
prev_value = unalignedLoad<ValueType>(source);
|
||||
unalignedStore<ValueType>(dest, prev_value);
|
||||
|
||||
if (source < source_end)
|
||||
{
|
||||
prev_delta = unalignedLoad<UnsignedDeltaType>(source);
|
||||
prev_value = prev_value + static_cast<ValueType>(prev_delta);
|
||||
unalignedStore<ValueType>(dest, prev_value);
|
||||
source += sizeof(prev_value);
|
||||
dest += sizeof(prev_value);
|
||||
|
||||
source += sizeof(prev_delta);
|
||||
dest += sizeof(prev_value);
|
||||
}
|
||||
// decoding second item
|
||||
if (source + sizeof(UnsignedDeltaType) > source_end || items_count < 2)
|
||||
return;
|
||||
|
||||
prev_delta = unalignedLoad<UnsignedDeltaType>(source);
|
||||
prev_value = prev_value + static_cast<ValueType>(prev_delta);
|
||||
unalignedStore<ValueType>(dest, prev_value);
|
||||
|
||||
source += sizeof(prev_delta);
|
||||
dest += sizeof(prev_value);
|
||||
|
||||
BitReader reader(source, source_size - sizeof(prev_value) - sizeof(prev_delta) - sizeof(items_count));
|
||||
|
||||
|
@ -159,19 +159,23 @@ void decompressDataForType(const char * source, UInt32 source_size, char * dest)
|
||||
|
||||
const char * source_end = source + source_size;
|
||||
|
||||
if (source + sizeof(UInt32) > source_end)
|
||||
return;
|
||||
|
||||
const UInt32 items_count = unalignedLoad<UInt32>(source);
|
||||
source += sizeof(items_count);
|
||||
|
||||
T prev_value{};
|
||||
|
||||
if (source < source_end)
|
||||
{
|
||||
prev_value = unalignedLoad<T>(source);
|
||||
unalignedStore<T>(dest, prev_value);
|
||||
// decoding first item
|
||||
if (source + sizeof(T) > source_end || items_count < 1)
|
||||
return;
|
||||
|
||||
source += sizeof(prev_value);
|
||||
dest += sizeof(prev_value);
|
||||
}
|
||||
prev_value = unalignedLoad<T>(source);
|
||||
unalignedStore<T>(dest, prev_value);
|
||||
|
||||
source += sizeof(prev_value);
|
||||
dest += sizeof(prev_value);
|
||||
|
||||
BitReader reader(source, source_size - sizeof(items_count) - sizeof(prev_value));
|
||||
|
||||
|
@ -23,6 +23,78 @@ extern const int LOGICAL_ERROR;
|
||||
namespace
|
||||
{
|
||||
|
||||
/// Fixed TypeIds that numbers would not be changed between versions.
|
||||
enum class MagicNumber : uint8_t
|
||||
{
|
||||
UInt8 = 1,
|
||||
UInt16 = 2,
|
||||
UInt32 = 3,
|
||||
UInt64 = 4,
|
||||
Int8 = 6,
|
||||
Int16 = 7,
|
||||
Int32 = 8,
|
||||
Int64 = 9,
|
||||
Date = 13,
|
||||
DateTime = 14,
|
||||
DateTime64 = 15,
|
||||
Enum8 = 17,
|
||||
Enum16 = 18,
|
||||
Decimal32 = 19,
|
||||
Decimal64 = 20,
|
||||
};
|
||||
|
||||
MagicNumber serializeTypeId(TypeIndex type_id)
|
||||
{
|
||||
switch (type_id)
|
||||
{
|
||||
case TypeIndex::UInt8: return MagicNumber::UInt8;
|
||||
case TypeIndex::UInt16: return MagicNumber::UInt16;
|
||||
case TypeIndex::UInt32: return MagicNumber::UInt32;
|
||||
case TypeIndex::UInt64: return MagicNumber::UInt64;
|
||||
case TypeIndex::Int8: return MagicNumber::Int8;
|
||||
case TypeIndex::Int16: return MagicNumber::Int16;
|
||||
case TypeIndex::Int32: return MagicNumber::Int32;
|
||||
case TypeIndex::Int64: return MagicNumber::Int64;
|
||||
case TypeIndex::Date: return MagicNumber::Date;
|
||||
case TypeIndex::DateTime: return MagicNumber::DateTime;
|
||||
case TypeIndex::DateTime64: return MagicNumber::DateTime64;
|
||||
case TypeIndex::Enum8: return MagicNumber::Enum8;
|
||||
case TypeIndex::Enum16: return MagicNumber::Enum16;
|
||||
case TypeIndex::Decimal32: return MagicNumber::Decimal32;
|
||||
case TypeIndex::Decimal64: return MagicNumber::Decimal64;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
throw Exception("Type is not supported by T64 codec: " + toString(UInt32(type_id)), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
TypeIndex deserializeTypeId(uint8_t serialized_type_id)
|
||||
{
|
||||
MagicNumber magic = static_cast<MagicNumber>(serialized_type_id);
|
||||
switch (magic)
|
||||
{
|
||||
case MagicNumber::UInt8: return TypeIndex::UInt8;
|
||||
case MagicNumber::UInt16: return TypeIndex::UInt16;
|
||||
case MagicNumber::UInt32: return TypeIndex::UInt32;
|
||||
case MagicNumber::UInt64: return TypeIndex::UInt64;
|
||||
case MagicNumber::Int8: return TypeIndex::Int8;
|
||||
case MagicNumber::Int16: return TypeIndex::Int16;
|
||||
case MagicNumber::Int32: return TypeIndex::Int32;
|
||||
case MagicNumber::Int64: return TypeIndex::Int64;
|
||||
case MagicNumber::Date: return TypeIndex::Date;
|
||||
case MagicNumber::DateTime: return TypeIndex::DateTime;
|
||||
case MagicNumber::DateTime64: return TypeIndex::DateTime64;
|
||||
case MagicNumber::Enum8: return TypeIndex::Enum8;
|
||||
case MagicNumber::Enum16: return TypeIndex::Enum16;
|
||||
case MagicNumber::Decimal32: return TypeIndex::Decimal32;
|
||||
case MagicNumber::Decimal64: return TypeIndex::Decimal64;
|
||||
}
|
||||
|
||||
throw Exception("Bad magic number in T64 codec: " + toString(UInt32(serialized_type_id)), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
|
||||
UInt8 codecId()
|
||||
{
|
||||
return static_cast<UInt8>(CompressionMethodByte::T64);
|
||||
@ -41,6 +113,7 @@ TypeIndex baseType(TypeIndex type_idx)
|
||||
return TypeIndex::Int32;
|
||||
case TypeIndex::Int64:
|
||||
case TypeIndex::Decimal64:
|
||||
case TypeIndex::DateTime64:
|
||||
return TypeIndex::Int64;
|
||||
case TypeIndex::UInt8:
|
||||
case TypeIndex::Enum8:
|
||||
@ -79,6 +152,7 @@ TypeIndex typeIdx(const DataTypePtr & data_type)
|
||||
case TypeIndex::Int32:
|
||||
case TypeIndex::UInt32:
|
||||
case TypeIndex::DateTime:
|
||||
case TypeIndex::DateTime64:
|
||||
case TypeIndex::Decimal32:
|
||||
case TypeIndex::Int64:
|
||||
case TypeIndex::UInt64:
|
||||
@ -490,7 +564,7 @@ void decompressData(const char * src, UInt32 src_size, char * dst, UInt32 uncomp
|
||||
|
||||
UInt32 CompressionCodecT64::doCompressData(const char * src, UInt32 src_size, char * dst) const
|
||||
{
|
||||
UInt8 cookie = static_cast<UInt8>(type_idx) | (static_cast<UInt8>(variant) << 7);
|
||||
UInt8 cookie = static_cast<UInt8>(serializeTypeId(type_idx)) | (static_cast<UInt8>(variant) << 7);
|
||||
memcpy(dst, &cookie, 1);
|
||||
dst += 1;
|
||||
|
||||
@ -529,7 +603,7 @@ void CompressionCodecT64::doDecompressData(const char * src, UInt32 src_size, ch
|
||||
src_size -= 1;
|
||||
|
||||
auto saved_variant = static_cast<Variant>(cookie >> 7);
|
||||
auto saved_type_id = static_cast<TypeIndex>(cookie & 0x7F);
|
||||
TypeIndex saved_type_id = deserializeTypeId(cookie & 0x7F);
|
||||
|
||||
switch (baseType(saved_type_id))
|
||||
{
|
||||
|
@ -158,8 +158,8 @@ public:
|
||||
|
||||
explicit BinaryDataAsSequenceOfValuesIterator(const Container & container_)
|
||||
: container(container_),
|
||||
data(&container[0]),
|
||||
data_end(reinterpret_cast<const char *>(data) + container.size()),
|
||||
data(container.data()),
|
||||
data_end(container.data() + container.size()),
|
||||
current_value(T{})
|
||||
{
|
||||
static_assert(sizeof(container[0]) == 1 && std::is_pod<std::decay_t<decltype(container[0])>>::value, "Only works on containers of byte-size PODs.");
|
||||
@ -789,12 +789,14 @@ auto FFand0Generator = []()
|
||||
};
|
||||
|
||||
|
||||
// Makes many sequences with generator, first sequence length is 1, second is 2... up to `sequences_count`.
|
||||
// Makes many sequences with generator, first sequence length is 0, second is 1..., third is 2 up to `sequences_count`.
|
||||
template <typename T, typename Generator>
|
||||
std::vector<CodecTestSequence> generatePyramidOfSequences(const size_t sequences_count, Generator && generator, const char* generator_name)
|
||||
{
|
||||
std::vector<CodecTestSequence> sequences;
|
||||
sequences.reserve(sequences_count);
|
||||
|
||||
sequences.push_back(makeSeq<T>()); // sequence of size 0
|
||||
for (size_t i = 1; i < sequences_count; ++i)
|
||||
{
|
||||
std::string name = generator_name + std::string(" from 0 to ") + std::to_string(i);
|
||||
|
@ -6,6 +6,7 @@
|
||||
#define DBMS_DEFAULT_HTTP_PORT 8123
|
||||
#define DBMS_DEFAULT_CONNECT_TIMEOUT_SEC 10
|
||||
#define DBMS_DEFAULT_CONNECT_TIMEOUT_WITH_FAILOVER_MS 50
|
||||
#define DBMS_DEFAULT_CONNECT_TIMEOUT_WITH_FAILOVER_SECURE_MS 100
|
||||
#define DBMS_DEFAULT_SEND_TIMEOUT_SEC 300
|
||||
#define DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC 300
|
||||
/// Timeout for synchronous request-result protocol call (like Ping or TablesStatus).
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include <Common/PODArray.h>
|
||||
#include <Core/Types.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Access/AccessControlManager.h>
|
||||
#include <Access/User.h>
|
||||
#include <IO/copyData.h>
|
||||
#include <IO/LimitReadBuffer.h>
|
||||
@ -952,7 +953,7 @@ public:
|
||||
throw Exception("Wrong size of auth response. Expected: " + std::to_string(Poco::SHA1Engine::DIGEST_SIZE) + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.",
|
||||
ErrorCodes::UNKNOWN_EXCEPTION);
|
||||
|
||||
auto user = context.getUser(user_name);
|
||||
auto user = context.getAccessControlManager().getUser(user_name);
|
||||
|
||||
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1();
|
||||
assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);
|
||||
|
@ -62,6 +62,7 @@ struct Settings : public SettingsCollection<Settings>
|
||||
M(SettingUInt64, interactive_delay, 100000, "The interval in microseconds to check if the request is cancelled, and to send progress info.", 0) \
|
||||
M(SettingSeconds, connect_timeout, DBMS_DEFAULT_CONNECT_TIMEOUT_SEC, "Connection timeout if there are no replicas.", 0) \
|
||||
M(SettingMilliseconds, connect_timeout_with_failover_ms, DBMS_DEFAULT_CONNECT_TIMEOUT_WITH_FAILOVER_MS, "Connection timeout for selecting first healthy replica.", 0) \
|
||||
M(SettingMilliseconds, connect_timeout_with_failover_secure_ms, DBMS_DEFAULT_CONNECT_TIMEOUT_WITH_FAILOVER_SECURE_MS, "Connection timeout for selecting first healthy replica (for secure connections).", 0) \
|
||||
M(SettingSeconds, receive_timeout, DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC, "", 0) \
|
||||
M(SettingSeconds, send_timeout, DBMS_DEFAULT_SEND_TIMEOUT_SEC, "", 0) \
|
||||
M(SettingSeconds, tcp_keep_alive_timeout, 0, "The time in seconds the connection needs to remain idle before TCP starts sending keepalive probes", 0) \
|
||||
@ -203,6 +204,7 @@ struct Settings : public SettingsCollection<Settings>
|
||||
M(SettingUInt64, output_format_parquet_row_group_size, 1000000, "Row group size in rows.", 0) \
|
||||
M(SettingString, output_format_avro_codec, "", "Compression codec used for output. Possible values: 'null', 'deflate', 'snappy'.", 0) \
|
||||
M(SettingUInt64, output_format_avro_sync_interval, 16 * 1024, "Sync interval in bytes.", 0) \
|
||||
M(SettingBool, output_format_tsv_crlf_end_of_line, false, "If it is set true, end of line in TSV format will be \\r\\n instead of \\n.", 0) \
|
||||
\
|
||||
M(SettingBool, use_client_time_zone, false, "Use client timezone for interpreting DateTime string values, instead of adopting server timezone.", 0) \
|
||||
\
|
||||
@ -330,6 +332,7 @@ struct Settings : public SettingsCollection<Settings>
|
||||
M(SettingUInt64, max_memory_usage, 0, "Maximum memory usage for processing of single query. Zero means unlimited.", 0) \
|
||||
M(SettingUInt64, max_memory_usage_for_user, 0, "Maximum memory usage for processing all concurrently running queries for the user. Zero means unlimited.", 0) \
|
||||
M(SettingUInt64, max_memory_usage_for_all_queries, 0, "Maximum memory usage for processing all concurrently running queries on the server. Zero means unlimited.", 0) \
|
||||
M(SettingUInt64, memory_profiler_step, 0, "Every number of bytes the memory profiler will dump the allocating stacktrace. Zero means disabled memory profiler.", 0) \
|
||||
\
|
||||
M(SettingUInt64, max_network_bandwidth, 0, "The maximum speed of data exchange over the network in bytes per second for a query. Zero means unlimited.", 0) \
|
||||
M(SettingUInt64, max_network_bytes, 0, "The maximum number of bytes (compressed) to receive or transmit over the network for execution of the query.", 0) \
|
||||
@ -338,7 +341,7 @@ struct Settings : public SettingsCollection<Settings>
|
||||
M(SettingChar, format_csv_delimiter, ',', "The character to be considered as a delimiter in CSV data. If setting with a string, a string has to have a length of 1.", 0) \
|
||||
M(SettingBool, format_csv_allow_single_quotes, 1, "If it is set to true, allow strings in single quotes.", 0) \
|
||||
M(SettingBool, format_csv_allow_double_quotes, 1, "If it is set to true, allow strings in double quotes.", 0) \
|
||||
M(SettingBool, output_format_csv_crlf_end_of_line, false, "If it is set true, end of line will be \\r\\n instead of \\n.", 0) \
|
||||
M(SettingBool, output_format_csv_crlf_end_of_line, false, "If it is set true, end of line in CSV format will be \\r\\n instead of \\n.", 0) \
|
||||
M(SettingBool, input_format_csv_unquoted_null_literal_as_null, false, "Consider unquoted NULL literal as \\N", 0) \
|
||||
\
|
||||
M(SettingDateTimeInputFormat, date_time_input_format, FormatSettings::DateTimeInputFormat::Basic, "Method to read DateTime from text input formats. Possible values: 'basic' and 'best_effort'.", 0) \
|
||||
@ -390,6 +393,9 @@ struct Settings : public SettingsCollection<Settings>
|
||||
M(SettingUInt64, mutations_sync, 0, "Wait for synchronous execution of ALTER TABLE UPDATE/DELETE queries (mutations). 0 - execute asynchronously. 1 - wait current server. 2 - wait all replicas if they exist.", 0) \
|
||||
M(SettingBool, optimize_if_chain_to_miltiif, false, "Replace if(cond1, then1, if(cond2, ...)) chains to multiIf. Currently it's not beneficial for numeric types.", 0) \
|
||||
M(SettingBool, allow_experimental_alter_materialized_view_structure, false, "Allow atomic alter on Materialized views. Work in progress.", 0) \
|
||||
M(SettingBool, enable_early_constant_folding, true, "Enable query optimization where we analyze function and subqueries results and rewrite query if there're constants there", 0) \
|
||||
\
|
||||
M(SettingBool, partial_revokes, false, "Makes it possible to revoke privileges partially.", 0) \
|
||||
\
|
||||
/** Obsolete settings that do nothing but left for compatibility reasons. Remove each one after half a year of obsolescence. */ \
|
||||
\
|
||||
|
@ -421,14 +421,7 @@ void SettingURI::set(const Field & x)
|
||||
|
||||
void SettingURI::set(const String & x)
|
||||
{
|
||||
try {
|
||||
Poco::URI uri(x);
|
||||
set(uri);
|
||||
}
|
||||
catch (const Poco::Exception& e)
|
||||
{
|
||||
throw Exception{Exception::CreateFromPoco, e};
|
||||
}
|
||||
set(Poco::URI(x));
|
||||
}
|
||||
|
||||
void SettingURI::serialize(WriteBuffer & buf, SettingsBinaryFormat) const
|
||||
|
@ -14,6 +14,7 @@ namespace DB
|
||||
|
||||
struct Null {};
|
||||
|
||||
/// @note Except explicitly described you should not assume on TypeIndex numbers and/or their orders in this enum.
|
||||
enum class TypeIndex
|
||||
{
|
||||
Nothing = 0,
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <DataStreams/ExpressionBlockInputStream.h>
|
||||
#include <DataStreams/CheckConstraintsBlockOutputStream.h>
|
||||
#include <Parsers/formatAST.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <Columns/ColumnsCommon.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Common/assert_cast.h>
|
||||
|
@ -34,7 +34,7 @@ void ParallelParsingBlockInputStream::segmentatorThreadFunction()
|
||||
|
||||
unit.is_last = !have_more_data;
|
||||
unit.status = READY_TO_PARSE;
|
||||
scheduleParserThreadForUnitWithNumber(current_unit_number);
|
||||
scheduleParserThreadForUnitWithNumber(segmentator_ticket_number);
|
||||
++segmentator_ticket_number;
|
||||
|
||||
if (!have_more_data)
|
||||
@ -49,12 +49,13 @@ void ParallelParsingBlockInputStream::segmentatorThreadFunction()
|
||||
}
|
||||
}
|
||||
|
||||
void ParallelParsingBlockInputStream::parserThreadFunction(size_t current_unit_number)
|
||||
void ParallelParsingBlockInputStream::parserThreadFunction(size_t current_ticket_number)
|
||||
{
|
||||
try
|
||||
{
|
||||
setThreadName("ChunkParser");
|
||||
|
||||
const auto current_unit_number = current_ticket_number % processing_units.size();
|
||||
auto & unit = processing_units[current_unit_number];
|
||||
|
||||
/*
|
||||
@ -64,9 +65,9 @@ void ParallelParsingBlockInputStream::parserThreadFunction(size_t current_unit_n
|
||||
* can use it from multiple threads simultaneously.
|
||||
*/
|
||||
ReadBuffer read_buffer(unit.segment.data(), unit.segment.size(), 0);
|
||||
auto parser = std::make_unique<InputStreamFromInputFormat>(
|
||||
input_processor_creator(read_buffer, header,
|
||||
row_input_format_params, format_settings));
|
||||
auto format = input_processor_creator(read_buffer, header, row_input_format_params, format_settings);
|
||||
format->setCurrentUnitNumber(current_ticket_number);
|
||||
auto parser = std::make_unique<InputStreamFromInputFormat>(std::move(format));
|
||||
|
||||
unit.block_ext.block.clear();
|
||||
unit.block_ext.block_missing_values.clear();
|
||||
|
@ -213,9 +213,9 @@ private:
|
||||
std::deque<ProcessingUnit> processing_units;
|
||||
|
||||
|
||||
void scheduleParserThreadForUnitWithNumber(size_t unit_number)
|
||||
void scheduleParserThreadForUnitWithNumber(size_t ticket_number)
|
||||
{
|
||||
pool.scheduleOrThrowOnError(std::bind(&ParallelParsingBlockInputStream::parserThreadFunction, this, unit_number));
|
||||
pool.scheduleOrThrowOnError(std::bind(&ParallelParsingBlockInputStream::parserThreadFunction, this, ticket_number));
|
||||
}
|
||||
|
||||
void finishAndWait()
|
||||
|
@ -8,7 +8,6 @@
|
||||
#include <Parsers/ASTInsertQuery.h>
|
||||
#include <Common/CurrentThread.h>
|
||||
#include <Common/setThreadName.h>
|
||||
#include <Common/getNumberOfPhysicalCPUCores.h>
|
||||
#include <Common/ThreadPool.h>
|
||||
#include <Storages/MergeTree/ReplicatedMergeTreeBlockOutputStream.h>
|
||||
#include <Storages/StorageValues.h>
|
||||
@ -51,8 +50,10 @@ PushingToViewsBlockOutputStream::PushingToViewsBlockOutputStream(
|
||||
ASTPtr query;
|
||||
BlockOutputStreamPtr out;
|
||||
|
||||
if (auto * materialized_view = dynamic_cast<const StorageMaterializedView *>(dependent_table.get()))
|
||||
if (auto * materialized_view = dynamic_cast<StorageMaterializedView *>(dependent_table.get()))
|
||||
{
|
||||
addTableLock(materialized_view->lockStructureForShare(true, context.getInitialQueryId()));
|
||||
|
||||
StoragePtr inner_table = materialized_view->getTargetTable();
|
||||
auto inner_table_id = inner_table->getStorageID();
|
||||
query = materialized_view->getInnerQuery();
|
||||
|
@ -621,6 +621,12 @@ inline bool isStringOrFixedString(const T & data_type)
|
||||
return WhichDataType(data_type).isStringOrFixedString();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool isNotCreatable(const T & data_type)
|
||||
{
|
||||
WhichDataType which(data_type);
|
||||
return which.isNothing() || which.isFunction() || which.isSet();
|
||||
}
|
||||
|
||||
inline bool isNotDecimalButComparableToDecimal(const DataTypePtr & data_type)
|
||||
{
|
||||
|
@ -252,20 +252,23 @@ void DatabaseOrdinary::alterTable(
|
||||
ast->replace(ast_create_query.select, metadata.select);
|
||||
}
|
||||
|
||||
ASTStorage & storage_ast = *ast_create_query.storage;
|
||||
/// ORDER BY may change, but cannot appear, it's required construction
|
||||
if (metadata.order_by_ast && storage_ast.order_by)
|
||||
storage_ast.set(storage_ast.order_by, metadata.order_by_ast);
|
||||
/// MaterializedView is one type of CREATE query without storage.
|
||||
if (ast_create_query.storage)
|
||||
{
|
||||
ASTStorage & storage_ast = *ast_create_query.storage;
|
||||
/// ORDER BY may change, but cannot appear, it's required construction
|
||||
if (metadata.order_by_ast && storage_ast.order_by)
|
||||
storage_ast.set(storage_ast.order_by, metadata.order_by_ast);
|
||||
|
||||
if (metadata.primary_key_ast)
|
||||
storage_ast.set(storage_ast.primary_key, metadata.primary_key_ast);
|
||||
if (metadata.primary_key_ast)
|
||||
storage_ast.set(storage_ast.primary_key, metadata.primary_key_ast);
|
||||
|
||||
if (metadata.ttl_for_table_ast)
|
||||
storage_ast.set(storage_ast.ttl_table, metadata.ttl_for_table_ast);
|
||||
|
||||
if (metadata.settings_ast)
|
||||
storage_ast.set(storage_ast.settings, metadata.settings_ast);
|
||||
if (metadata.ttl_for_table_ast)
|
||||
storage_ast.set(storage_ast.ttl_table, metadata.ttl_for_table_ast);
|
||||
|
||||
if (metadata.settings_ast)
|
||||
storage_ast.set(storage_ast.settings, metadata.settings_ast);
|
||||
}
|
||||
|
||||
statement = getObjectDefinitionFromCreateQuery(ast);
|
||||
{
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <ext/range.h>
|
||||
#include <ext/size.h>
|
||||
#include <Common/setThreadName.h>
|
||||
#include "CacheDictionary.inc.h"
|
||||
#include "DictionaryBlockInputStream.h"
|
||||
#include "DictionaryFactory.h"
|
||||
@ -61,24 +62,48 @@ CacheDictionary::CacheDictionary(
|
||||
const std::string & name_,
|
||||
const DictionaryStructure & dict_struct_,
|
||||
DictionarySourcePtr source_ptr_,
|
||||
const DictionaryLifetime dict_lifetime_,
|
||||
const size_t size_)
|
||||
DictionaryLifetime dict_lifetime_,
|
||||
size_t size_,
|
||||
bool allow_read_expired_keys_,
|
||||
size_t max_update_queue_size_,
|
||||
size_t update_queue_push_timeout_milliseconds_,
|
||||
size_t max_threads_for_updates_)
|
||||
: database(database_)
|
||||
, name(name_)
|
||||
, full_name{database_.empty() ? name_ : (database_ + "." + name_)}
|
||||
, dict_struct(dict_struct_)
|
||||
, source_ptr{std::move(source_ptr_)}
|
||||
, dict_lifetime(dict_lifetime_)
|
||||
, allow_read_expired_keys(allow_read_expired_keys_)
|
||||
, max_update_queue_size(max_update_queue_size_)
|
||||
, update_queue_push_timeout_milliseconds(update_queue_push_timeout_milliseconds_)
|
||||
, max_threads_for_updates(max_threads_for_updates_)
|
||||
, log(&Logger::get("ExternalDictionaries"))
|
||||
, size{roundUpToPowerOfTwoOrZero(std::max(size_, size_t(max_collision_length)))}
|
||||
, size_overlap_mask{this->size - 1}
|
||||
, cells{this->size}
|
||||
, rnd_engine(randomSeed())
|
||||
, update_queue(max_update_queue_size_)
|
||||
, update_pool(max_threads_for_updates)
|
||||
{
|
||||
if (!this->source_ptr->supportsSelectiveLoad())
|
||||
throw Exception{full_name + ": source cannot be used with CacheDictionary", ErrorCodes::UNSUPPORTED_METHOD};
|
||||
|
||||
createAttributes();
|
||||
for (size_t i = 0; i < max_threads_for_updates; ++i)
|
||||
update_pool.scheduleOrThrowOnError([this] { updateThreadFunction(); });
|
||||
}
|
||||
|
||||
CacheDictionary::~CacheDictionary()
|
||||
{
|
||||
finished = true;
|
||||
update_queue.clear();
|
||||
for (size_t i = 0; i < max_threads_for_updates; ++i)
|
||||
{
|
||||
auto empty_finishing_ptr = std::make_shared<UpdateUnit>(std::vector<Key>());
|
||||
update_queue.push(empty_finishing_ptr);
|
||||
}
|
||||
update_pool.wait();
|
||||
}
|
||||
|
||||
|
||||
@ -275,10 +300,16 @@ CacheDictionary::FindResult CacheDictionary::findCellIdx(const Key & id, const C
|
||||
|
||||
void CacheDictionary::has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8> & out) const
|
||||
{
|
||||
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
|
||||
std::unordered_map<Key, std::vector<size_t>> outdated_ids;
|
||||
/// There are three types of ids.
|
||||
/// - Valid ids. These ids are presented in local cache and their lifetime is not expired.
|
||||
/// - CacheExpired ids. Ids that are in local cache, but their values are rotted (lifetime is expired).
|
||||
/// - CacheNotFound ids. We have to go to external storage to know its value.
|
||||
|
||||
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
|
||||
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
|
||||
std::unordered_map<Key, std::vector<size_t>> cache_expired_ids;
|
||||
std::unordered_map<Key, std::vector<size_t>> cache_not_found_ids;
|
||||
|
||||
size_t cache_hit = 0;
|
||||
|
||||
const auto rows = ext::size(ids);
|
||||
{
|
||||
@ -291,49 +322,97 @@ void CacheDictionary::has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8>
|
||||
const auto id = ids[row];
|
||||
const auto find_result = findCellIdx(id, now);
|
||||
const auto & cell_idx = find_result.cell_idx;
|
||||
|
||||
auto insert_to_answer_routine = [&] ()
|
||||
{
|
||||
out[row] = !cells[cell_idx].isDefault();
|
||||
};
|
||||
|
||||
if (!find_result.valid)
|
||||
{
|
||||
outdated_ids[id].push_back(row);
|
||||
if (find_result.outdated)
|
||||
++cache_expired;
|
||||
{
|
||||
cache_expired_ids[id].push_back(row);
|
||||
|
||||
if (allow_read_expired_keys)
|
||||
insert_to_answer_routine();
|
||||
}
|
||||
else
|
||||
++cache_not_found;
|
||||
{
|
||||
cache_not_found_ids[id].push_back(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
++cache_hit;
|
||||
const auto & cell = cells[cell_idx];
|
||||
out[row] = !cell.isDefault();
|
||||
insert_to_answer_routine();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired_ids.size());
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found_ids.size());
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
|
||||
|
||||
query_count.fetch_add(rows, std::memory_order_relaxed);
|
||||
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
|
||||
hit_count.fetch_add(rows - cache_expired_ids.size() - cache_not_found_ids.size(), std::memory_order_release);
|
||||
|
||||
if (outdated_ids.empty())
|
||||
return;
|
||||
if (cache_not_found_ids.empty())
|
||||
{
|
||||
/// Nothing to update - return;
|
||||
if (cache_expired_ids.empty())
|
||||
return;
|
||||
|
||||
std::vector<Key> required_ids(outdated_ids.size());
|
||||
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), [](auto & pair) { return pair.first; });
|
||||
|
||||
/// request new values
|
||||
update(
|
||||
required_ids,
|
||||
[&](const auto id, const auto)
|
||||
if (allow_read_expired_keys)
|
||||
{
|
||||
for (const auto row : outdated_ids[id])
|
||||
out[row] = true;
|
||||
},
|
||||
[&](const auto id, const auto)
|
||||
{
|
||||
for (const auto row : outdated_ids[id])
|
||||
out[row] = false;
|
||||
});
|
||||
std::vector<Key> required_expired_ids;
|
||||
required_expired_ids.reserve(cache_expired_ids.size());
|
||||
std::transform(
|
||||
std::begin(cache_expired_ids), std::end(cache_expired_ids),
|
||||
std::back_inserter(required_expired_ids), [](auto & pair) { return pair.first; });
|
||||
|
||||
/// Callbacks are empty because we don't want to receive them after an unknown period of time.
|
||||
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_expired_ids);
|
||||
|
||||
tryPushToUpdateQueueOrThrow(update_unit_ptr);
|
||||
/// Update is async - no need to wait.
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// At this point we have two situations.
|
||||
/// There may be both types of keys: cache_expired_ids and cache_not_found_ids.
|
||||
/// We will update them all synchronously.
|
||||
|
||||
std::vector<Key> required_ids;
|
||||
required_ids.reserve(cache_not_found_ids.size() + cache_expired_ids.size());
|
||||
std::transform(
|
||||
std::begin(cache_not_found_ids), std::end(cache_not_found_ids),
|
||||
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
|
||||
std::transform(
|
||||
std::begin(cache_expired_ids), std::end(cache_expired_ids),
|
||||
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
|
||||
|
||||
auto on_cell_updated = [&] (const Key id, const size_t)
|
||||
{
|
||||
for (const auto row : cache_not_found_ids[id])
|
||||
out[row] = true;
|
||||
for (const auto row : cache_expired_ids[id])
|
||||
out[row] = true;
|
||||
};
|
||||
|
||||
auto on_id_not_found = [&] (const Key id, const size_t)
|
||||
{
|
||||
for (const auto row : cache_not_found_ids[id])
|
||||
out[row] = false;
|
||||
for (const auto row : cache_expired_ids[id])
|
||||
out[row] = true;
|
||||
};
|
||||
|
||||
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_ids, on_cell_updated, on_id_not_found);
|
||||
|
||||
tryPushToUpdateQueueOrThrow(update_unit_ptr);
|
||||
waitForCurrentUpdateFinish(update_unit_ptr);
|
||||
}
|
||||
|
||||
|
||||
@ -590,7 +669,8 @@ void registerDictionaryCache(DictionaryFactory & factory)
|
||||
DictionarySourcePtr source_ptr) -> DictionaryPtr
|
||||
{
|
||||
if (dict_struct.key)
|
||||
throw Exception{"'key' is not supported for dictionary of layout 'cache'", ErrorCodes::UNSUPPORTED_METHOD};
|
||||
throw Exception{"'key' is not supported for dictionary of layout 'cache'",
|
||||
ErrorCodes::UNSUPPORTED_METHOD};
|
||||
|
||||
if (dict_struct.range_min || dict_struct.range_max)
|
||||
throw Exception{full_name
|
||||
@ -598,9 +678,11 @@ void registerDictionaryCache(DictionaryFactory & factory)
|
||||
"for a dictionary of layout 'range_hashed'",
|
||||
ErrorCodes::BAD_ARGUMENTS};
|
||||
const auto & layout_prefix = config_prefix + ".layout";
|
||||
const auto size = config.getInt(layout_prefix + ".cache.size_in_cells");
|
||||
|
||||
const size_t size = config.getUInt64(layout_prefix + ".cache.size_in_cells");
|
||||
if (size == 0)
|
||||
throw Exception{full_name + ": dictionary of layout 'cache' cannot have 0 cells", ErrorCodes::TOO_SMALL_BUFFER_SIZE};
|
||||
throw Exception{full_name + ": dictionary of layout 'cache' cannot have 0 cells",
|
||||
ErrorCodes::TOO_SMALL_BUFFER_SIZE};
|
||||
|
||||
const bool require_nonempty = config.getBool(config_prefix + ".require_nonempty", false);
|
||||
if (require_nonempty)
|
||||
@ -610,10 +692,284 @@ void registerDictionaryCache(DictionaryFactory & factory)
|
||||
const String database = config.getString(config_prefix + ".database", "");
|
||||
const String name = config.getString(config_prefix + ".name");
|
||||
const DictionaryLifetime dict_lifetime{config, config_prefix + ".lifetime"};
|
||||
return std::make_unique<CacheDictionary>(database, name, dict_struct, std::move(source_ptr), dict_lifetime, size);
|
||||
|
||||
const size_t max_update_queue_size =
|
||||
config.getUInt64(layout_prefix + ".cache.max_update_queue_size", 100000);
|
||||
if (max_update_queue_size == 0)
|
||||
throw Exception{name + ": dictionary of layout 'cache' cannot have empty update queue of size 0",
|
||||
ErrorCodes::TOO_SMALL_BUFFER_SIZE};
|
||||
|
||||
const bool allow_read_expired_keys =
|
||||
config.getBool(layout_prefix + ".cache.allow_read_expired_keys", false);
|
||||
|
||||
const size_t update_queue_push_timeout_milliseconds =
|
||||
config.getUInt64(layout_prefix + ".cache.update_queue_push_timeout_milliseconds", 10);
|
||||
if (update_queue_push_timeout_milliseconds < 10)
|
||||
throw Exception{name + ": dictionary of layout 'cache' have too little update_queue_push_timeout",
|
||||
ErrorCodes::BAD_ARGUMENTS};
|
||||
|
||||
const size_t max_threads_for_updates =
|
||||
config.getUInt64(layout_prefix + ".max_threads_for_updates", 4);
|
||||
if (max_threads_for_updates == 0)
|
||||
throw Exception{name + ": dictionary of layout 'cache' cannot have zero threads for updates.",
|
||||
ErrorCodes::BAD_ARGUMENTS};
|
||||
|
||||
return std::make_unique<CacheDictionary>(
|
||||
database, name, dict_struct, std::move(source_ptr), dict_lifetime, size,
|
||||
allow_read_expired_keys, max_update_queue_size, update_queue_push_timeout_milliseconds,
|
||||
max_threads_for_updates);
|
||||
};
|
||||
factory.registerLayout("cache", create_layout, false);
|
||||
}
|
||||
|
||||
void CacheDictionary::updateThreadFunction()
|
||||
{
|
||||
setThreadName("AsyncUpdater");
|
||||
while (!finished)
|
||||
{
|
||||
UpdateUnitPtr first_popped;
|
||||
update_queue.pop(first_popped);
|
||||
|
||||
if (finished)
|
||||
break;
|
||||
|
||||
/// Here we pop as many unit pointers from update queue as we can.
|
||||
/// We fix current size to avoid livelock (or too long waiting),
|
||||
/// when this thread pops from the queue and other threads push to the queue.
|
||||
const size_t current_queue_size = update_queue.size();
|
||||
|
||||
if (current_queue_size > 0)
|
||||
LOG_TRACE(log, "Performing bunch of keys update in cache dictionary with "
|
||||
<< current_queue_size + 1 << " keys");
|
||||
|
||||
std::vector<UpdateUnitPtr> update_request;
|
||||
update_request.reserve(current_queue_size + 1);
|
||||
update_request.emplace_back(first_popped);
|
||||
|
||||
UpdateUnitPtr current_unit_ptr;
|
||||
|
||||
while (update_request.size() && update_queue.tryPop(current_unit_ptr))
|
||||
update_request.emplace_back(std::move(current_unit_ptr));
|
||||
|
||||
BunchUpdateUnit bunch_update_unit(update_request);
|
||||
|
||||
try
|
||||
{
|
||||
/// Update a bunch of ids.
|
||||
update(bunch_update_unit);
|
||||
|
||||
/// Notify all threads about finished updating the bunch of ids
|
||||
/// where their own ids were included.
|
||||
std::unique_lock<std::mutex> lock(update_mutex);
|
||||
|
||||
for (auto & unit_ptr: update_request)
|
||||
unit_ptr->is_done = true;
|
||||
|
||||
is_update_finished.notify_all();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(update_mutex);
|
||||
/// It is a big trouble, because one bad query can make other threads fail with not relative exception.
|
||||
/// So at this point all threads (and queries) will receive the same exception.
|
||||
for (auto & unit_ptr: update_request)
|
||||
unit_ptr->current_exception = std::current_exception();
|
||||
|
||||
is_update_finished.notify_all();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CacheDictionary::waitForCurrentUpdateFinish(UpdateUnitPtr & update_unit_ptr) const
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(update_mutex);
|
||||
|
||||
/*
|
||||
* We wait here without any timeout to avoid SEGFAULT's.
|
||||
* Consider timeout for wait had expired and main query's thread ended with exception
|
||||
* or some other error. But the UpdateUnit with callbacks is left in the queue.
|
||||
* It has these callback that capture god knows what from the current thread
|
||||
* (most of the variables lies on the stack of finished thread) that
|
||||
* intended to do a synchronous update. AsyncUpdate thread can touch deallocated memory and explode.
|
||||
* */
|
||||
is_update_finished.wait(
|
||||
lock,
|
||||
[&] {return update_unit_ptr->is_done || update_unit_ptr->current_exception; });
|
||||
|
||||
if (update_unit_ptr->current_exception)
|
||||
std::rethrow_exception(update_unit_ptr->current_exception);
|
||||
}
|
||||
|
||||
void CacheDictionary::tryPushToUpdateQueueOrThrow(UpdateUnitPtr & update_unit_ptr) const
|
||||
{
|
||||
if (!update_queue.tryPush(update_unit_ptr, update_queue_push_timeout_milliseconds))
|
||||
throw DB::Exception(
|
||||
"Cannot push to internal update queue in dictionary " + getFullName() + ". Timelimit of " +
|
||||
std::to_string(update_queue_push_timeout_milliseconds) + " ms. exceeded. Current queue size is " +
|
||||
std::to_string(update_queue.size()), ErrorCodes::CACHE_DICTIONARY_UPDATE_FAIL);
|
||||
}
|
||||
|
||||
void CacheDictionary::update(BunchUpdateUnit & bunch_update_unit) const
|
||||
{
|
||||
CurrentMetrics::Increment metric_increment{CurrentMetrics::DictCacheRequests};
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequested, bunch_update_unit.getRequestedIds().size());
|
||||
|
||||
std::unordered_map<Key, UInt8> remaining_ids{bunch_update_unit.getRequestedIds().size()};
|
||||
for (const auto id : bunch_update_unit.getRequestedIds())
|
||||
remaining_ids.insert({id, 0});
|
||||
|
||||
const auto now = std::chrono::system_clock::now();
|
||||
|
||||
if (now > backoff_end_time.load())
|
||||
{
|
||||
try
|
||||
{
|
||||
if (error_count)
|
||||
{
|
||||
/// Recover after error: we have to clone the source here because
|
||||
/// it could keep connections which should be reset after error.
|
||||
source_ptr = source_ptr->clone();
|
||||
}
|
||||
|
||||
Stopwatch watch;
|
||||
auto stream = source_ptr->loadIds(bunch_update_unit.getRequestedIds());
|
||||
|
||||
const ProfilingScopedWriteRWLock write_lock{rw_lock, ProfileEvents::DictCacheLockWriteNs};
|
||||
|
||||
stream->readPrefix();
|
||||
while (const auto block = stream->read())
|
||||
{
|
||||
const auto id_column = typeid_cast<const ColumnUInt64 *>(block.safeGetByPosition(0).column.get());
|
||||
if (!id_column)
|
||||
throw Exception{name + ": id column has type different from UInt64.", ErrorCodes::TYPE_MISMATCH};
|
||||
|
||||
const auto & ids = id_column->getData();
|
||||
|
||||
/// cache column pointers
|
||||
const auto column_ptrs = ext::map<std::vector>(
|
||||
ext::range(0, attributes.size()), [&block](size_t i) { return block.safeGetByPosition(i + 1).column.get(); });
|
||||
|
||||
for (const auto i : ext::range(0, ids.size()))
|
||||
{
|
||||
const auto id = ids[i];
|
||||
|
||||
const auto find_result = findCellIdx(id, now);
|
||||
const auto & cell_idx = find_result.cell_idx;
|
||||
|
||||
auto & cell = cells[cell_idx];
|
||||
|
||||
for (const auto attribute_idx : ext::range(0, attributes.size()))
|
||||
{
|
||||
const auto & attribute_column = *column_ptrs[attribute_idx];
|
||||
auto & attribute = attributes[attribute_idx];
|
||||
|
||||
setAttributeValue(attribute, cell_idx, attribute_column[i]);
|
||||
}
|
||||
|
||||
/// if cell id is zero and zero does not map to this cell, then the cell is unused
|
||||
if (cell.id == 0 && cell_idx != zero_cell_idx)
|
||||
element_count.fetch_add(1, std::memory_order_relaxed);
|
||||
|
||||
cell.id = id;
|
||||
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
|
||||
{
|
||||
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
|
||||
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
|
||||
}
|
||||
else
|
||||
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
|
||||
|
||||
|
||||
bunch_update_unit.informCallersAboutPresentId(id, cell_idx);
|
||||
/// mark corresponding id as found
|
||||
remaining_ids[id] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
stream->readSuffix();
|
||||
|
||||
error_count = 0;
|
||||
last_exception = std::exception_ptr{};
|
||||
backoff_end_time = std::chrono::system_clock::time_point{};
|
||||
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheRequestTimeNs, watch.elapsed());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
++error_count;
|
||||
last_exception = std::current_exception();
|
||||
backoff_end_time = now + std::chrono::seconds(calculateDurationWithBackoff(rnd_engine, error_count));
|
||||
|
||||
tryLogException(last_exception, log, "Could not update cache dictionary '" + getFullName() +
|
||||
"', next update is scheduled at " + ext::to_string(backoff_end_time.load()));
|
||||
}
|
||||
}
|
||||
|
||||
size_t not_found_num = 0, found_num = 0;
|
||||
|
||||
const ProfilingScopedWriteRWLock write_lock{rw_lock, ProfileEvents::DictCacheLockWriteNs};
|
||||
|
||||
/// Check which ids have not been found and require setting null_value
|
||||
for (const auto & id_found_pair : remaining_ids)
|
||||
{
|
||||
if (id_found_pair.second)
|
||||
{
|
||||
++found_num;
|
||||
continue;
|
||||
}
|
||||
++not_found_num;
|
||||
|
||||
const auto id = id_found_pair.first;
|
||||
|
||||
const auto find_result = findCellIdx(id, now);
|
||||
const auto & cell_idx = find_result.cell_idx;
|
||||
auto & cell = cells[cell_idx];
|
||||
|
||||
if (error_count)
|
||||
{
|
||||
if (find_result.outdated)
|
||||
{
|
||||
/// We have expired data for that `id` so we can continue using it.
|
||||
bool was_default = cell.isDefault();
|
||||
cell.setExpiresAt(backoff_end_time);
|
||||
if (was_default)
|
||||
cell.setDefault();
|
||||
if (was_default)
|
||||
bunch_update_unit.informCallersAboutAbsentId(id, cell_idx);
|
||||
else
|
||||
bunch_update_unit.informCallersAboutPresentId(id, cell_idx);
|
||||
continue;
|
||||
}
|
||||
/// We don't have expired data for that `id` so all we can do is to rethrow `last_exception`.
|
||||
std::rethrow_exception(last_exception);
|
||||
}
|
||||
|
||||
/// Check if cell had not been occupied before and increment element counter if it hadn't
|
||||
if (cell.id == 0 && cell_idx != zero_cell_idx)
|
||||
element_count.fetch_add(1, std::memory_order_relaxed);
|
||||
|
||||
cell.id = id;
|
||||
|
||||
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
|
||||
{
|
||||
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
|
||||
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
|
||||
}
|
||||
else
|
||||
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
|
||||
|
||||
/// Set null_value for each attribute
|
||||
cell.setDefault();
|
||||
for (auto & attribute : attributes)
|
||||
setDefaultAttributeValue(attribute, cell_idx);
|
||||
|
||||
/// inform caller that the cell has not been found
|
||||
bunch_update_unit.informCallersAboutAbsentId(id, cell_idx);
|
||||
}
|
||||
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedMiss, not_found_num);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedFound, found_num);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheRequests);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -4,12 +4,16 @@
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
#include <common/logger_useful.h>
|
||||
#include <Columns/ColumnDecimal.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <Common/ThreadPool.h>
|
||||
#include <Common/ConcurrentBoundedQueue.h>
|
||||
#include <pcg_random.hpp>
|
||||
#include <Common/ArenaWithFreeLists.h>
|
||||
#include <Common/CurrentMetrics.h>
|
||||
@ -21,6 +25,22 @@
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int CACHE_DICTIONARY_UPDATE_FAIL;
|
||||
}
|
||||
|
||||
/*
|
||||
*
|
||||
* This dictionary is stored in a cache that has a fixed number of cells.
|
||||
* These cells contain frequently used elements.
|
||||
* When searching for a dictionary, the cache is searched first and special heuristic is used:
|
||||
* while looking for the key, we take a look only at max_collision_length elements.
|
||||
* So, our cache is not perfect. It has errors like "the key is in cache, but the cache says that it does not".
|
||||
* And in this case we simply ask external source for the key which is faster.
|
||||
* You have to keep this logic in mind.
|
||||
* */
|
||||
class CacheDictionary final : public IDictionary
|
||||
{
|
||||
public:
|
||||
@ -29,8 +49,14 @@ public:
|
||||
const std::string & name_,
|
||||
const DictionaryStructure & dict_struct_,
|
||||
DictionarySourcePtr source_ptr_,
|
||||
const DictionaryLifetime dict_lifetime_,
|
||||
const size_t size_);
|
||||
DictionaryLifetime dict_lifetime_,
|
||||
size_t size_,
|
||||
bool allow_read_expired_keys_,
|
||||
size_t max_update_queue_size_,
|
||||
size_t update_queue_push_timeout_milliseconds_,
|
||||
size_t max_threads_for_updates);
|
||||
|
||||
~CacheDictionary() override;
|
||||
|
||||
const std::string & getDatabase() const override { return database; }
|
||||
const std::string & getName() const override { return name; }
|
||||
@ -55,7 +81,10 @@ public:
|
||||
|
||||
std::shared_ptr<const IExternalLoadable> clone() const override
|
||||
{
|
||||
return std::make_shared<CacheDictionary>(database, name, dict_struct, source_ptr->clone(), dict_lifetime, size);
|
||||
return std::make_shared<CacheDictionary>(
|
||||
database, name, dict_struct, source_ptr->clone(), dict_lifetime, size,
|
||||
allow_read_expired_keys, max_update_queue_size,
|
||||
update_queue_push_timeout_milliseconds, max_threads_for_updates);
|
||||
}
|
||||
|
||||
const IDictionarySource * getSource() const override { return source_ptr.get(); }
|
||||
@ -230,9 +259,6 @@ private:
|
||||
template <typename DefaultGetter>
|
||||
void getItemsString(Attribute & attribute, const PaddedPODArray<Key> & ids, ColumnString * out, DefaultGetter && get_default) const;
|
||||
|
||||
template <typename PresentIdHandler, typename AbsentIdHandler>
|
||||
void update(const std::vector<Key> & requested_ids, PresentIdHandler && on_cell_updated, AbsentIdHandler && on_id_not_found) const;
|
||||
|
||||
PaddedPODArray<Key> getCachedIds() const;
|
||||
|
||||
bool isEmptyCell(const UInt64 idx) const;
|
||||
@ -263,6 +289,11 @@ private:
|
||||
const DictionaryStructure dict_struct;
|
||||
mutable DictionarySourcePtr source_ptr;
|
||||
const DictionaryLifetime dict_lifetime;
|
||||
const bool allow_read_expired_keys;
|
||||
const size_t max_update_queue_size;
|
||||
const size_t update_queue_push_timeout_milliseconds;
|
||||
const size_t max_threads_for_updates;
|
||||
|
||||
Logger * const log;
|
||||
|
||||
mutable std::shared_mutex rw_lock;
|
||||
@ -284,8 +315,8 @@ private:
|
||||
std::unique_ptr<ArenaWithFreeLists> string_arena;
|
||||
|
||||
mutable std::exception_ptr last_exception;
|
||||
mutable size_t error_count = 0;
|
||||
mutable std::chrono::system_clock::time_point backoff_end_time;
|
||||
mutable std::atomic<size_t> error_count = 0;
|
||||
mutable std::atomic<std::chrono::system_clock::time_point> backoff_end_time{std::chrono::system_clock::time_point{}};
|
||||
|
||||
mutable pcg64 rnd_engine;
|
||||
|
||||
@ -293,6 +324,166 @@ private:
|
||||
mutable std::atomic<size_t> element_count{0};
|
||||
mutable std::atomic<size_t> hit_count{0};
|
||||
mutable std::atomic<size_t> query_count{0};
|
||||
};
|
||||
|
||||
/// Field and methods correlated with update expired and not found keys
|
||||
|
||||
using PresentIdHandler = std::function<void(Key, size_t)>;
|
||||
using AbsentIdHandler = std::function<void(Key, size_t)>;
|
||||
|
||||
/*
|
||||
* Disclaimer: this comment is written not for fun.
|
||||
*
|
||||
* How the update goes: we basically have a method like get(keys)->values. Values are cached, so sometimes we
|
||||
* can return them from the cache. For values not in cache, we query them from the dictionary, and add to the
|
||||
* cache. The cache is lossy, so we can't expect it to store all the keys, and we store them separately. Normally,
|
||||
* they would be passed as a return value of get(), but for Unknown Reasons the dictionaries use a baroque
|
||||
* interface where get() accepts two callback, one that it calls for found values, and one for not found.
|
||||
*
|
||||
* Now we make it even uglier by doing this from multiple threads. The missing values are retreived from the
|
||||
* dictionary in a background thread, and this thread calls the provided callback. So if you provide the callbacks,
|
||||
* you MUST wait until the background update finishes, or god knows what happens. Unfortunately, we have no
|
||||
* way to check that you did this right, so good luck.
|
||||
*/
|
||||
struct UpdateUnit
|
||||
{
|
||||
UpdateUnit(std::vector<Key> requested_ids_,
|
||||
PresentIdHandler present_id_handler_,
|
||||
AbsentIdHandler absent_id_handler_) :
|
||||
requested_ids(std::move(requested_ids_)),
|
||||
present_id_handler(present_id_handler_),
|
||||
absent_id_handler(absent_id_handler_) {}
|
||||
|
||||
explicit UpdateUnit(std::vector<Key> requested_ids_) :
|
||||
requested_ids(std::move(requested_ids_)),
|
||||
present_id_handler([](Key, size_t){}),
|
||||
absent_id_handler([](Key, size_t){}) {}
|
||||
|
||||
std::vector<Key> requested_ids;
|
||||
PresentIdHandler present_id_handler;
|
||||
AbsentIdHandler absent_id_handler;
|
||||
|
||||
std::atomic<bool> is_done{false};
|
||||
std::exception_ptr current_exception{nullptr};
|
||||
};
|
||||
|
||||
using UpdateUnitPtr = std::shared_ptr<UpdateUnit>;
|
||||
using UpdateQueue = ConcurrentBoundedQueue<UpdateUnitPtr>;
|
||||
|
||||
|
||||
/*
|
||||
* This class is used to concatenate requested_keys.
|
||||
*
|
||||
* Imagine that we have several UpdateUnit with different vectors of keys and callbacks for that keys.
|
||||
* We concatenate them into a long vector of keys that looks like:
|
||||
*
|
||||
* a1...ak_a b1...bk_2 c1...ck_3,
|
||||
*
|
||||
* where a1...ak_a are requested_keys from the first UpdateUnit.
|
||||
* In addition we have the same number (three in this case) of callbacks.
|
||||
* This class helps us to find a callback (or many callbacks) for a special key.
|
||||
* */
|
||||
|
||||
class BunchUpdateUnit
|
||||
{
|
||||
public:
|
||||
explicit BunchUpdateUnit(std::vector<UpdateUnitPtr> & update_request)
|
||||
{
|
||||
/// Here we prepare total count of all requested ids
|
||||
/// not to do useless allocations later.
|
||||
size_t total_requested_keys_count = 0;
|
||||
|
||||
for (auto & unit_ptr: update_request)
|
||||
{
|
||||
total_requested_keys_count += unit_ptr->requested_ids.size();
|
||||
if (helper.empty())
|
||||
helper.push_back(unit_ptr->requested_ids.size());
|
||||
else
|
||||
helper.push_back(unit_ptr->requested_ids.size() + helper.back());
|
||||
present_id_handlers.emplace_back(unit_ptr->present_id_handler);
|
||||
absent_id_handlers.emplace_back(unit_ptr->absent_id_handler);
|
||||
}
|
||||
|
||||
concatenated_requested_ids.reserve(total_requested_keys_count);
|
||||
for (auto & unit_ptr: update_request)
|
||||
std::for_each(std::begin(unit_ptr->requested_ids), std::end(unit_ptr->requested_ids),
|
||||
[&] (const Key & key) {concatenated_requested_ids.push_back(key);});
|
||||
|
||||
}
|
||||
|
||||
const std::vector<Key> & getRequestedIds()
|
||||
{
|
||||
return concatenated_requested_ids;
|
||||
}
|
||||
|
||||
void informCallersAboutPresentId(Key id, size_t cell_idx)
|
||||
{
|
||||
for (size_t i = 0; i < concatenated_requested_ids.size(); ++i)
|
||||
{
|
||||
auto & curr = concatenated_requested_ids[i];
|
||||
if (curr == id)
|
||||
getPresentIdHandlerForPosition(i)(id, cell_idx);
|
||||
}
|
||||
}
|
||||
|
||||
void informCallersAboutAbsentId(Key id, size_t cell_idx)
|
||||
{
|
||||
for (size_t i = 0; i < concatenated_requested_ids.size(); ++i)
|
||||
if (concatenated_requested_ids[i] == id)
|
||||
getAbsentIdHandlerForPosition(i)(id, cell_idx);
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
PresentIdHandler & getPresentIdHandlerForPosition(size_t position)
|
||||
{
|
||||
return present_id_handlers[getUpdateUnitNumberForRequestedIdPosition(position)];
|
||||
}
|
||||
|
||||
AbsentIdHandler & getAbsentIdHandlerForPosition(size_t position)
|
||||
{
|
||||
return absent_id_handlers[getUpdateUnitNumberForRequestedIdPosition((position))];
|
||||
}
|
||||
|
||||
size_t getUpdateUnitNumberForRequestedIdPosition(size_t position)
|
||||
{
|
||||
return std::lower_bound(helper.begin(), helper.end(), position) - helper.begin();
|
||||
}
|
||||
|
||||
std::vector<Key> concatenated_requested_ids;
|
||||
std::vector<PresentIdHandler> present_id_handlers;
|
||||
std::vector<AbsentIdHandler> absent_id_handlers;
|
||||
|
||||
std::vector<size_t> helper;
|
||||
};
|
||||
|
||||
mutable UpdateQueue update_queue;
|
||||
|
||||
ThreadPool update_pool;
|
||||
|
||||
/*
|
||||
* Actually, we can divide all requested keys into two 'buckets'. There are only four possible states and they
|
||||
* are described in the table.
|
||||
*
|
||||
* cache_not_found_ids |0|0|1|1|
|
||||
* cache_expired_ids |0|1|0|1|
|
||||
*
|
||||
* 0 - if set is empty, 1 - otherwise
|
||||
*
|
||||
* Only if there are no cache_not_found_ids and some cache_expired_ids
|
||||
* (with allow_read_expired_keys_from_cache_dictionary setting) we can perform async update.
|
||||
* Otherwise we have no concatenate ids and update them sync.
|
||||
*
|
||||
*/
|
||||
void updateThreadFunction();
|
||||
void update(BunchUpdateUnit & bunch_update_unit) const;
|
||||
|
||||
|
||||
void tryPushToUpdateQueueOrThrow(UpdateUnitPtr & update_unit_ptr) const;
|
||||
void waitForCurrentUpdateFinish(UpdateUnitPtr & update_unit_ptr) const;
|
||||
|
||||
mutable std::mutex update_mutex;
|
||||
mutable std::condition_variable is_update_finished;
|
||||
|
||||
std::atomic<bool> finished{false};
|
||||
};
|
||||
}
|
||||
|
@ -40,11 +40,13 @@ void CacheDictionary::getItemsNumberImpl(
|
||||
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const
|
||||
{
|
||||
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
|
||||
std::unordered_map<Key, std::vector<size_t>> outdated_ids;
|
||||
std::unordered_map<Key, std::vector<size_t>> cache_expired_ids;
|
||||
std::unordered_map<Key, std::vector<size_t>> cache_not_found_ids;
|
||||
|
||||
auto & attribute_array = std::get<ContainerPtrType<AttributeType>>(attribute.arrays);
|
||||
const auto rows = ext::size(ids);
|
||||
|
||||
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
|
||||
size_t cache_hit = 0;
|
||||
|
||||
{
|
||||
const ProfilingScopedReadRWLock read_lock{rw_lock, ProfileEvents::DictCacheLockReadNs};
|
||||
@ -61,52 +63,105 @@ void CacheDictionary::getItemsNumberImpl(
|
||||
* 3. explicit defaults were specified and cell was set default. */
|
||||
|
||||
const auto find_result = findCellIdx(id, now);
|
||||
|
||||
auto update_routine = [&]()
|
||||
{
|
||||
const auto & cell_idx = find_result.cell_idx;
|
||||
const auto & cell = cells[cell_idx];
|
||||
out[row] = cell.isDefault() ? get_default(row) : static_cast<OutputType>(attribute_array[cell_idx]);
|
||||
};
|
||||
|
||||
if (!find_result.valid)
|
||||
{
|
||||
outdated_ids[id].push_back(row);
|
||||
|
||||
if (find_result.outdated)
|
||||
++cache_expired;
|
||||
{
|
||||
cache_expired_ids[id].push_back(row);
|
||||
if (allow_read_expired_keys)
|
||||
update_routine();
|
||||
}
|
||||
else
|
||||
++cache_not_found;
|
||||
{
|
||||
cache_not_found_ids[id].push_back(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
++cache_hit;
|
||||
const auto & cell_idx = find_result.cell_idx;
|
||||
const auto & cell = cells[cell_idx];
|
||||
out[row] = cell.isDefault() ? get_default(row) : static_cast<OutputType>(attribute_array[cell_idx]);
|
||||
update_routine();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired_ids.size());
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found_ids.size());
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
|
||||
|
||||
query_count.fetch_add(rows, std::memory_order_relaxed);
|
||||
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
|
||||
hit_count.fetch_add(rows - cache_expired_ids.size() - cache_not_found_ids.size(), std::memory_order_release);
|
||||
|
||||
if (outdated_ids.empty())
|
||||
return;
|
||||
if (cache_not_found_ids.empty())
|
||||
{
|
||||
/// Nothing to update - return
|
||||
if (cache_expired_ids.empty())
|
||||
return;
|
||||
|
||||
std::vector<Key> required_ids(outdated_ids.size());
|
||||
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), [](auto & pair) { return pair.first; });
|
||||
|
||||
/// request new values
|
||||
update(
|
||||
required_ids,
|
||||
[&](const auto id, const auto cell_idx)
|
||||
/// Update async only if allow_read_expired_keys_is_enabledadd condvar usage and better code
|
||||
if (allow_read_expired_keys)
|
||||
{
|
||||
const auto attribute_value = attribute_array[cell_idx];
|
||||
std::vector<Key> required_expired_ids;
|
||||
required_expired_ids.reserve(cache_expired_ids.size());
|
||||
std::transform(std::begin(cache_expired_ids), std::end(cache_expired_ids), std::back_inserter(required_expired_ids),
|
||||
[](auto & pair) { return pair.first; });
|
||||
|
||||
for (const size_t row : outdated_ids[id])
|
||||
out[row] = static_cast<OutputType>(attribute_value);
|
||||
},
|
||||
[&](const auto id, const auto)
|
||||
{
|
||||
for (const size_t row : outdated_ids[id])
|
||||
out[row] = get_default(row);
|
||||
});
|
||||
/// request new values
|
||||
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_expired_ids);
|
||||
|
||||
tryPushToUpdateQueueOrThrow(update_unit_ptr);
|
||||
|
||||
/// Nothing to do - return
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// From this point we have to update all keys sync.
|
||||
/// Maybe allow_read_expired_keys_from_cache_dictionary is disabled
|
||||
/// and there no cache_not_found_ids but some cache_expired.
|
||||
|
||||
std::vector<Key> required_ids;
|
||||
required_ids.reserve(cache_not_found_ids.size() + cache_expired_ids.size());
|
||||
std::transform(
|
||||
std::begin(cache_not_found_ids), std::end(cache_not_found_ids),
|
||||
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
|
||||
std::transform(
|
||||
std::begin(cache_expired_ids), std::end(cache_expired_ids),
|
||||
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
|
||||
|
||||
auto on_cell_updated = [&] (const auto id, const auto cell_idx)
|
||||
{
|
||||
const auto attribute_value = attribute_array[cell_idx];
|
||||
|
||||
for (const size_t row : cache_not_found_ids[id])
|
||||
out[row] = static_cast<OutputType>(attribute_value);
|
||||
|
||||
for (const size_t row : cache_expired_ids[id])
|
||||
out[row] = static_cast<OutputType>(attribute_value);
|
||||
};
|
||||
|
||||
auto on_id_not_found = [&] (const auto id, const auto)
|
||||
{
|
||||
for (const size_t row : cache_not_found_ids[id])
|
||||
out[row] = get_default(row);
|
||||
|
||||
for (const size_t row : cache_expired_ids[id])
|
||||
out[row] = get_default(row);
|
||||
};
|
||||
|
||||
/// Request new values
|
||||
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_ids, on_cell_updated, on_id_not_found);
|
||||
|
||||
tryPushToUpdateQueueOrThrow(update_unit_ptr);
|
||||
waitForCurrentUpdateFinish(update_unit_ptr);
|
||||
}
|
||||
|
||||
template <typename DefaultGetter>
|
||||
@ -161,12 +216,13 @@ void CacheDictionary::getItemsString(
|
||||
out->getOffsets().resize_assume_reserved(0);
|
||||
|
||||
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
|
||||
std::unordered_map<Key, std::vector<size_t>> outdated_ids;
|
||||
std::unordered_map<Key, std::vector<size_t>> cache_expired_ids;
|
||||
std::unordered_map<Key, std::vector<size_t>> cache_not_found_ids;
|
||||
/// we are going to store every string separately
|
||||
std::unordered_map<Key, String> map;
|
||||
|
||||
size_t total_length = 0;
|
||||
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
|
||||
size_t cache_hit = 0;
|
||||
{
|
||||
const ProfilingScopedReadRWLock read_lock{rw_lock, ProfileEvents::DictCacheLockReadNs};
|
||||
|
||||
@ -176,17 +232,10 @@ void CacheDictionary::getItemsString(
|
||||
const auto id = ids[row];
|
||||
|
||||
const auto find_result = findCellIdx(id, now);
|
||||
if (!find_result.valid)
|
||||
|
||||
|
||||
auto insert_value_routine = [&]()
|
||||
{
|
||||
outdated_ids[id].push_back(row);
|
||||
if (find_result.outdated)
|
||||
++cache_expired;
|
||||
else
|
||||
++cache_not_found;
|
||||
}
|
||||
else
|
||||
{
|
||||
++cache_hit;
|
||||
const auto & cell_idx = find_result.cell_idx;
|
||||
const auto & cell = cells[cell_idx];
|
||||
const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
|
||||
@ -195,37 +244,82 @@ void CacheDictionary::getItemsString(
|
||||
map[id] = String{string_ref};
|
||||
|
||||
total_length += string_ref.size + 1;
|
||||
};
|
||||
|
||||
if (!find_result.valid)
|
||||
{
|
||||
if (find_result.outdated)
|
||||
{
|
||||
cache_expired_ids[id].push_back(row);
|
||||
|
||||
if (allow_read_expired_keys)
|
||||
insert_value_routine();
|
||||
} else
|
||||
cache_not_found_ids[id].push_back(row);
|
||||
} else
|
||||
{
|
||||
++cache_hit;
|
||||
insert_value_routine();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired_ids.size());
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found_ids.size());
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
|
||||
|
||||
query_count.fetch_add(rows, std::memory_order_relaxed);
|
||||
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
|
||||
hit_count.fetch_add(rows - cache_expired_ids.size() - cache_not_found_ids.size(), std::memory_order_release);
|
||||
|
||||
/// request new values
|
||||
if (!outdated_ids.empty())
|
||||
/// Async update of expired keys.
|
||||
if (cache_not_found_ids.empty())
|
||||
{
|
||||
std::vector<Key> required_ids(outdated_ids.size());
|
||||
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), [](auto & pair) { return pair.first; });
|
||||
if (allow_read_expired_keys && !cache_expired_ids.empty())
|
||||
{
|
||||
std::vector<Key> required_expired_ids;
|
||||
required_expired_ids.reserve(cache_not_found_ids.size());
|
||||
std::transform(std::begin(cache_expired_ids), std::end(cache_expired_ids),
|
||||
std::back_inserter(required_expired_ids), [](auto & pair) { return pair.first; });
|
||||
|
||||
update(
|
||||
required_ids,
|
||||
[&](const auto id, const auto cell_idx)
|
||||
{
|
||||
const auto attribute_value = attribute_array[cell_idx];
|
||||
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_expired_ids);
|
||||
|
||||
map[id] = String{attribute_value};
|
||||
total_length += (attribute_value.size + 1) * outdated_ids[id].size();
|
||||
},
|
||||
[&](const auto id, const auto)
|
||||
{
|
||||
for (const auto row : outdated_ids[id])
|
||||
total_length += get_default(row).size + 1;
|
||||
});
|
||||
tryPushToUpdateQueueOrThrow(update_unit_ptr);
|
||||
|
||||
/// Do not return at this point, because there some extra stuff to do at the end of this method.
|
||||
}
|
||||
}
|
||||
|
||||
/// Request new values sync.
|
||||
/// We have request both cache_not_found_ids and cache_expired_ids.
|
||||
if (!cache_not_found_ids.empty())
|
||||
{
|
||||
std::vector<Key> required_ids;
|
||||
required_ids.reserve(cache_not_found_ids.size() + cache_expired_ids.size());
|
||||
std::transform(
|
||||
std::begin(cache_not_found_ids), std::end(cache_not_found_ids),
|
||||
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
|
||||
std::transform(
|
||||
std::begin(cache_expired_ids), std::end(cache_expired_ids),
|
||||
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
|
||||
|
||||
auto on_cell_updated = [&] (const auto id, const auto cell_idx)
|
||||
{
|
||||
const auto attribute_value = attribute_array[cell_idx];
|
||||
|
||||
map[id] = String{attribute_value};
|
||||
total_length += (attribute_value.size + 1) * cache_not_found_ids[id].size();
|
||||
};
|
||||
|
||||
auto on_id_not_found = [&] (const auto id, const auto)
|
||||
{
|
||||
for (const auto row : cache_not_found_ids[id])
|
||||
total_length += get_default(row).size + 1;
|
||||
};
|
||||
|
||||
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_ids, on_cell_updated, on_id_not_found);
|
||||
|
||||
tryPushToUpdateQueueOrThrow(update_unit_ptr);
|
||||
waitForCurrentUpdateFinish(update_unit_ptr);
|
||||
}
|
||||
|
||||
out->getChars().reserve(total_length);
|
||||
@ -240,167 +334,4 @@ void CacheDictionary::getItemsString(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PresentIdHandler, typename AbsentIdHandler>
|
||||
void CacheDictionary::update(
|
||||
const std::vector<Key> & requested_ids, PresentIdHandler && on_cell_updated, AbsentIdHandler && on_id_not_found) const
|
||||
{
|
||||
CurrentMetrics::Increment metric_increment{CurrentMetrics::DictCacheRequests};
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequested, requested_ids.size());
|
||||
|
||||
std::unordered_map<Key, UInt8> remaining_ids{requested_ids.size()};
|
||||
for (const auto id : requested_ids)
|
||||
remaining_ids.insert({id, 0});
|
||||
|
||||
const auto now = std::chrono::system_clock::now();
|
||||
|
||||
const ProfilingScopedWriteRWLock write_lock{rw_lock, ProfileEvents::DictCacheLockWriteNs};
|
||||
|
||||
if (now > backoff_end_time)
|
||||
{
|
||||
try
|
||||
{
|
||||
if (error_count)
|
||||
{
|
||||
/// Recover after error: we have to clone the source here because
|
||||
/// it could keep connections which should be reset after error.
|
||||
source_ptr = source_ptr->clone();
|
||||
}
|
||||
|
||||
Stopwatch watch;
|
||||
auto stream = source_ptr->loadIds(requested_ids);
|
||||
stream->readPrefix();
|
||||
|
||||
while (const auto block = stream->read())
|
||||
{
|
||||
const auto id_column = typeid_cast<const ColumnUInt64 *>(block.safeGetByPosition(0).column.get());
|
||||
if (!id_column)
|
||||
throw Exception{name + ": id column has type different from UInt64.", ErrorCodes::TYPE_MISMATCH};
|
||||
|
||||
const auto & ids = id_column->getData();
|
||||
|
||||
/// cache column pointers
|
||||
const auto column_ptrs = ext::map<std::vector>(
|
||||
ext::range(0, attributes.size()), [&block](size_t i) { return block.safeGetByPosition(i + 1).column.get(); });
|
||||
|
||||
for (const auto i : ext::range(0, ids.size()))
|
||||
{
|
||||
const auto id = ids[i];
|
||||
|
||||
const auto find_result = findCellIdx(id, now);
|
||||
const auto & cell_idx = find_result.cell_idx;
|
||||
|
||||
auto & cell = cells[cell_idx];
|
||||
|
||||
for (const auto attribute_idx : ext::range(0, attributes.size()))
|
||||
{
|
||||
const auto & attribute_column = *column_ptrs[attribute_idx];
|
||||
auto & attribute = attributes[attribute_idx];
|
||||
|
||||
setAttributeValue(attribute, cell_idx, attribute_column[i]);
|
||||
}
|
||||
|
||||
/// if cell id is zero and zero does not map to this cell, then the cell is unused
|
||||
if (cell.id == 0 && cell_idx != zero_cell_idx)
|
||||
element_count.fetch_add(1, std::memory_order_relaxed);
|
||||
|
||||
cell.id = id;
|
||||
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
|
||||
{
|
||||
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
|
||||
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
|
||||
}
|
||||
else
|
||||
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
|
||||
|
||||
/// inform caller
|
||||
on_cell_updated(id, cell_idx);
|
||||
/// mark corresponding id as found
|
||||
remaining_ids[id] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
stream->readSuffix();
|
||||
|
||||
error_count = 0;
|
||||
last_exception = std::exception_ptr{};
|
||||
backoff_end_time = std::chrono::system_clock::time_point{};
|
||||
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheRequestTimeNs, watch.elapsed());
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
++error_count;
|
||||
last_exception = std::current_exception();
|
||||
backoff_end_time = now + std::chrono::seconds(calculateDurationWithBackoff(rnd_engine, error_count));
|
||||
|
||||
tryLogException(last_exception, log, "Could not update cache dictionary '" + getFullName() +
|
||||
"', next update is scheduled at " + ext::to_string(backoff_end_time));
|
||||
}
|
||||
}
|
||||
|
||||
size_t not_found_num = 0, found_num = 0;
|
||||
|
||||
/// Check which ids have not been found and require setting null_value
|
||||
for (const auto & id_found_pair : remaining_ids)
|
||||
{
|
||||
if (id_found_pair.second)
|
||||
{
|
||||
++found_num;
|
||||
continue;
|
||||
}
|
||||
++not_found_num;
|
||||
|
||||
const auto id = id_found_pair.first;
|
||||
|
||||
const auto find_result = findCellIdx(id, now);
|
||||
const auto & cell_idx = find_result.cell_idx;
|
||||
auto & cell = cells[cell_idx];
|
||||
|
||||
if (error_count)
|
||||
{
|
||||
if (find_result.outdated)
|
||||
{
|
||||
/// We have expired data for that `id` so we can continue using it.
|
||||
bool was_default = cell.isDefault();
|
||||
cell.setExpiresAt(backoff_end_time);
|
||||
if (was_default)
|
||||
cell.setDefault();
|
||||
if (was_default)
|
||||
on_id_not_found(id, cell_idx);
|
||||
else
|
||||
on_cell_updated(id, cell_idx);
|
||||
continue;
|
||||
}
|
||||
/// We don't have expired data for that `id` so all we can do is to rethrow `last_exception`.
|
||||
std::rethrow_exception(last_exception);
|
||||
}
|
||||
|
||||
/// Check if cell had not been occupied before and increment element counter if it hadn't
|
||||
if (cell.id == 0 && cell_idx != zero_cell_idx)
|
||||
element_count.fetch_add(1, std::memory_order_relaxed);
|
||||
|
||||
cell.id = id;
|
||||
|
||||
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
|
||||
{
|
||||
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
|
||||
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
|
||||
}
|
||||
else
|
||||
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
|
||||
|
||||
/// Set null_value for each attribute
|
||||
cell.setDefault();
|
||||
for (auto & attribute : attributes)
|
||||
setDefaultAttributeValue(attribute, cell_idx);
|
||||
|
||||
/// inform caller that the cell has not been found
|
||||
on_id_not_found(id, cell_idx);
|
||||
}
|
||||
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedMiss, not_found_num);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedFound, found_num);
|
||||
ProfileEvents::increment(ProfileEvents::DictCacheRequests);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -97,6 +97,7 @@ static FormatSettings getOutputFormatSetting(const Settings & settings, const Co
|
||||
format_settings.template_settings.resultset_format = settings.format_template_resultset;
|
||||
format_settings.template_settings.row_format = settings.format_template_row;
|
||||
format_settings.template_settings.row_between_delimiter = settings.format_template_rows_between_delimiter;
|
||||
format_settings.tsv.crlf_end_of_line = settings.output_format_tsv_crlf_end_of_line;
|
||||
format_settings.write_statistics = settings.output_format_write_statistics;
|
||||
format_settings.parquet.row_group_size = settings.output_format_parquet_row_group_size;
|
||||
format_settings.schema.format_schema = settings.format_schema;
|
||||
@ -144,9 +145,19 @@ BlockInputStreamPtr FormatFactory::getInput(
|
||||
|
||||
// Doesn't make sense to use parallel parsing with less than four threads
|
||||
// (segmentator + two parsers + reader).
|
||||
if (settings.input_format_parallel_parsing
|
||||
&& file_segmentation_engine
|
||||
&& settings.max_threads >= 4)
|
||||
bool parallel_parsing = settings.input_format_parallel_parsing && file_segmentation_engine && settings.max_threads >= 4;
|
||||
|
||||
if (parallel_parsing && name == "JSONEachRow")
|
||||
{
|
||||
/// FIXME ParallelParsingBlockInputStream doesn't support formats with non-trivial readPrefix() and readSuffix()
|
||||
|
||||
/// For JSONEachRow we can safely skip whitespace characters
|
||||
skipWhitespaceIfAny(buf);
|
||||
if (buf.eof() || *buf.position() == '[')
|
||||
parallel_parsing = false; /// Disable it for JSONEachRow if data is in square brackets (see JSONEachRowRowInputFormat)
|
||||
}
|
||||
|
||||
if (parallel_parsing)
|
||||
{
|
||||
const auto & input_getter = getCreators(name).input_processor_creator;
|
||||
if (!input_getter)
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Types.h>
|
||||
#include <Columns/IColumn.h>
|
||||
#include <DataStreams/IBlockStream_fwd.h>
|
||||
#include <IO/BufferWithOwnMemory.h>
|
||||
|
||||
@ -9,7 +10,6 @@
|
||||
#include <unordered_map>
|
||||
#include <boost/noncopyable.hpp>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
@ -53,7 +53,9 @@ public:
|
||||
|
||||
/// This callback allows to perform some additional actions after writing a single row.
|
||||
/// It's initial purpose was to flush Kafka message for each row.
|
||||
using WriteCallback = std::function<void()>;
|
||||
using WriteCallback = std::function<void(
|
||||
const Columns & columns,
|
||||
size_t row)>;
|
||||
|
||||
private:
|
||||
using InputCreator = std::function<BlockInputStreamPtr(
|
||||
|
@ -64,6 +64,7 @@ struct FormatSettings
|
||||
struct TSV
|
||||
{
|
||||
bool empty_as_default = false;
|
||||
bool crlf_end_of_line = false;
|
||||
};
|
||||
|
||||
TSV tsv;
|
||||
|
@ -45,7 +45,7 @@ try
|
||||
BlockInputStreamPtr block_input = std::make_shared<InputStreamFromInputFormat>(std::move(input_format));
|
||||
|
||||
BlockOutputStreamPtr block_output = std::make_shared<OutputStreamToOutputFormat>(
|
||||
std::make_shared<TabSeparatedRowOutputFormat>(out_buf, sample, false, false, [] {}, format_settings));
|
||||
std::make_shared<TabSeparatedRowOutputFormat>(out_buf, sample, false, false, [](const Columns & /* columns */, size_t /* row */){}, format_settings));
|
||||
|
||||
copyData(*block_input, *block_output);
|
||||
return 0;
|
||||
|
@ -35,7 +35,6 @@
|
||||
#include <Columns/ColumnsCommon.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Common/assert_cast.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <Functions/IFunctionAdaptors.h>
|
||||
#include <Functions/FunctionsMiscellaneous.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
|
66
dbms/src/Functions/blockSerializedSize.cpp
Normal file
66
dbms/src/Functions/blockSerializedSize.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
#include <Functions/IFunctionImpl.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <IO/NullWriteBuffer.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Returns size on disk for *block* (without taking into account compression).
|
||||
class FunctionBlockSerializedSize : public IFunction
|
||||
{
|
||||
public:
|
||||
static constexpr auto name = "blockSerializedSize";
|
||||
|
||||
static FunctionPtr create(const Context &)
|
||||
{
|
||||
return std::make_shared<FunctionBlockSerializedSize>();
|
||||
}
|
||||
|
||||
String getName() const override { return name; }
|
||||
bool useDefaultImplementationForNulls() const override { return false; }
|
||||
size_t getNumberOfArguments() const override { return 0; }
|
||||
bool isVariadic() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override
|
||||
{
|
||||
return std::make_shared<DataTypeUInt64>();
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
|
||||
{
|
||||
UInt64 size = 0;
|
||||
|
||||
for (size_t i = 0; i < arguments.size(); ++i)
|
||||
size += blockSerializedSizeOne(block.getByPosition(arguments[i]));
|
||||
|
||||
block.getByPosition(result).column = DataTypeUInt64().createColumnConst(
|
||||
input_rows_count, size)->convertToFullColumnIfConst();
|
||||
}
|
||||
|
||||
UInt64 blockSerializedSizeOne(const ColumnWithTypeAndName & elem) const
|
||||
{
|
||||
ColumnPtr full_column = elem.column->convertToFullColumnIfConst();
|
||||
|
||||
IDataType::SerializeBinaryBulkSettings settings;
|
||||
NullWriteBuffer out;
|
||||
|
||||
settings.getter = [&out](IDataType::SubstreamPath) -> WriteBuffer * { return &out; };
|
||||
|
||||
IDataType::SerializeBinaryBulkStatePtr state;
|
||||
elem.type->serializeBinaryBulkWithMultipleStreams(*full_column,
|
||||
0 /** offset */, 0 /** limit */,
|
||||
settings, state);
|
||||
|
||||
return out.count();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
void registerFunctionBlockSerializedSize(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionBlockSerializedSize>();
|
||||
}
|
||||
|
||||
}
|
@ -73,9 +73,15 @@ public:
|
||||
for (size_t i = 0; i < input_rows_count; ++i)
|
||||
{
|
||||
StringRef source = column_concrete->getDataAt(i);
|
||||
int status = 0;
|
||||
std::string demangled = demangle(source.data, status);
|
||||
result_column->insertDataWithTerminatingZero(demangled.data(), demangled.size() + 1);
|
||||
auto demangled = tryDemangle(source.data);
|
||||
if (demangled)
|
||||
{
|
||||
result_column->insertDataWithTerminatingZero(demangled.get(), strlen(demangled.get()));
|
||||
}
|
||||
else
|
||||
{
|
||||
result_column->insertDataWithTerminatingZero(source.data, source.size);
|
||||
}
|
||||
}
|
||||
|
||||
block.getByPosition(result).column = std::move(result_column);
|
||||
|
@ -14,6 +14,7 @@ void registerFunctionFQDN(FunctionFactory &);
|
||||
void registerFunctionVisibleWidth(FunctionFactory &);
|
||||
void registerFunctionToTypeName(FunctionFactory &);
|
||||
void registerFunctionGetSizeOfEnumType(FunctionFactory &);
|
||||
void registerFunctionBlockSerializedSize(FunctionFactory &);
|
||||
void registerFunctionToColumnTypeName(FunctionFactory &);
|
||||
void registerFunctionDumpColumnStructure(FunctionFactory &);
|
||||
void registerFunctionDefaultValueOfArgumentType(FunctionFactory &);
|
||||
@ -72,6 +73,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
|
||||
registerFunctionVisibleWidth(factory);
|
||||
registerFunctionToTypeName(factory);
|
||||
registerFunctionGetSizeOfEnumType(factory);
|
||||
registerFunctionBlockSerializedSize(factory);
|
||||
registerFunctionToColumnTypeName(factory);
|
||||
registerFunctionDumpColumnStructure(factory);
|
||||
registerFunctionDefaultValueOfArgumentType(factory);
|
||||
|
@ -16,6 +16,7 @@ struct ConnectionTimeouts
|
||||
Poco::Timespan receive_timeout;
|
||||
Poco::Timespan tcp_keep_alive_timeout;
|
||||
Poco::Timespan http_keep_alive_timeout;
|
||||
Poco::Timespan secure_connection_timeout;
|
||||
|
||||
ConnectionTimeouts() = default;
|
||||
|
||||
@ -26,7 +27,8 @@ struct ConnectionTimeouts
|
||||
send_timeout(send_timeout_),
|
||||
receive_timeout(receive_timeout_),
|
||||
tcp_keep_alive_timeout(0),
|
||||
http_keep_alive_timeout(0)
|
||||
http_keep_alive_timeout(0),
|
||||
secure_connection_timeout(connection_timeout)
|
||||
{
|
||||
}
|
||||
|
||||
@ -38,7 +40,8 @@ struct ConnectionTimeouts
|
||||
send_timeout(send_timeout_),
|
||||
receive_timeout(receive_timeout_),
|
||||
tcp_keep_alive_timeout(tcp_keep_alive_timeout_),
|
||||
http_keep_alive_timeout(0)
|
||||
http_keep_alive_timeout(0),
|
||||
secure_connection_timeout(connection_timeout)
|
||||
{
|
||||
}
|
||||
ConnectionTimeouts(const Poco::Timespan & connection_timeout_,
|
||||
@ -50,10 +53,25 @@ struct ConnectionTimeouts
|
||||
send_timeout(send_timeout_),
|
||||
receive_timeout(receive_timeout_),
|
||||
tcp_keep_alive_timeout(tcp_keep_alive_timeout_),
|
||||
http_keep_alive_timeout(http_keep_alive_timeout_)
|
||||
http_keep_alive_timeout(http_keep_alive_timeout_),
|
||||
secure_connection_timeout(connection_timeout)
|
||||
{
|
||||
}
|
||||
|
||||
ConnectionTimeouts(const Poco::Timespan & connection_timeout_,
|
||||
const Poco::Timespan & send_timeout_,
|
||||
const Poco::Timespan & receive_timeout_,
|
||||
const Poco::Timespan & tcp_keep_alive_timeout_,
|
||||
const Poco::Timespan & http_keep_alive_timeout_,
|
||||
const Poco::Timespan & secure_connection_timeout_)
|
||||
: connection_timeout(connection_timeout_),
|
||||
send_timeout(send_timeout_),
|
||||
receive_timeout(receive_timeout_),
|
||||
tcp_keep_alive_timeout(tcp_keep_alive_timeout_),
|
||||
http_keep_alive_timeout(http_keep_alive_timeout_),
|
||||
secure_connection_timeout(secure_connection_timeout_)
|
||||
{
|
||||
}
|
||||
|
||||
static Poco::Timespan saturate(const Poco::Timespan & timespan, const Poco::Timespan & limit)
|
||||
{
|
||||
@ -69,7 +87,8 @@ struct ConnectionTimeouts
|
||||
saturate(send_timeout, limit),
|
||||
saturate(receive_timeout, limit),
|
||||
saturate(tcp_keep_alive_timeout, limit),
|
||||
saturate(http_keep_alive_timeout, limit));
|
||||
saturate(http_keep_alive_timeout, limit),
|
||||
saturate(secure_connection_timeout, limit));
|
||||
}
|
||||
|
||||
/// Timeouts for the case when we have just single attempt to connect.
|
||||
@ -81,7 +100,7 @@ struct ConnectionTimeouts
|
||||
/// Timeouts for the case when we will try many addresses in a loop.
|
||||
static ConnectionTimeouts getTCPTimeoutsWithFailover(const Settings & settings)
|
||||
{
|
||||
return ConnectionTimeouts(settings.connect_timeout_with_failover_ms, settings.send_timeout, settings.receive_timeout, settings.tcp_keep_alive_timeout);
|
||||
return ConnectionTimeouts(settings.connect_timeout_with_failover_ms, settings.send_timeout, settings.receive_timeout, settings.tcp_keep_alive_timeout, 0, settings.connect_timeout_with_failover_secure_ms);
|
||||
}
|
||||
|
||||
static ConnectionTimeouts getHTTPTimeouts(const Context & context)
|
||||
|
16
dbms/src/IO/NullWriteBuffer.cpp
Normal file
16
dbms/src/IO/NullWriteBuffer.cpp
Normal file
@ -0,0 +1,16 @@
|
||||
#include <IO/NullWriteBuffer.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
NullWriteBuffer::NullWriteBuffer(size_t buf_size, char * existing_memory, size_t alignment)
|
||||
: BufferWithOwnMemory<WriteBuffer>(buf_size, existing_memory, alignment)
|
||||
{
|
||||
}
|
||||
|
||||
void NullWriteBuffer::nextImpl()
|
||||
{
|
||||
}
|
||||
|
||||
}
|
18
dbms/src/IO/NullWriteBuffer.h
Normal file
18
dbms/src/IO/NullWriteBuffer.h
Normal file
@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/WriteBuffer.h>
|
||||
#include <IO/BufferWithOwnMemory.h>
|
||||
#include <boost/noncopyable.hpp>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Simply do nothing, can be used to measure amount of written bytes.
|
||||
class NullWriteBuffer : public BufferWithOwnMemory<WriteBuffer>, boost::noncopyable
|
||||
{
|
||||
public:
|
||||
NullWriteBuffer(size_t buf_size = 16<<10, char * existing_memory = nullptr, size_t alignment = false);
|
||||
void nextImpl() override;
|
||||
};
|
||||
|
||||
}
|
@ -6,17 +6,28 @@
|
||||
#include <port/unistd.h>
|
||||
#include <IO/ReadBufferAIO.h>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
namespace
|
||||
{
|
||||
std::string createTmpFileForEOFtest()
|
||||
{
|
||||
char pattern[] = "/tmp/fileXXXXXX";
|
||||
char * dir = ::mkdtemp(pattern);
|
||||
return std::string(dir) + "/foo";
|
||||
if (char * dir = ::mkdtemp(pattern); dir)
|
||||
{
|
||||
return std::string(dir) + "/foo";
|
||||
}
|
||||
else
|
||||
{
|
||||
/// We have no tmp in docker
|
||||
/// So we have to use root
|
||||
std::string almost_rand_dir = std::string{"/"} + std::to_string(rand()) + "foo";
|
||||
return almost_rand_dir;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void prepare_for_eof(std::string & filename, std::string & buf)
|
||||
void prepareForEOF(std::string & filename, std::string & buf)
|
||||
{
|
||||
static const std::string symbols = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
|
||||
|
||||
@ -28,7 +39,7 @@ void prepare_for_eof(std::string & filename, std::string & buf)
|
||||
for (size_t i = 0; i < n; ++i)
|
||||
buf += symbols[i % symbols.length()];
|
||||
|
||||
std::ofstream out(filename.c_str());
|
||||
std::ofstream out(filename);
|
||||
out << buf;
|
||||
}
|
||||
|
||||
@ -39,7 +50,7 @@ TEST(ReadBufferAIOTest, TestReadAfterAIO)
|
||||
using namespace DB;
|
||||
std::string data;
|
||||
std::string file_path;
|
||||
prepare_for_eof(file_path, data);
|
||||
prepareForEOF(file_path, data);
|
||||
ReadBufferAIO testbuf(file_path);
|
||||
|
||||
std::string newdata;
|
||||
|
@ -1,3 +1,4 @@
|
||||
#include "Common/quoteString.h"
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <Core/Row.h>
|
||||
@ -334,7 +335,7 @@ void ActionsMatcher::visit(const ASTIdentifier & identifier, const ASTPtr & ast,
|
||||
found = true;
|
||||
|
||||
if (found)
|
||||
throw Exception("Column " + column_name.get(ast) + " is not under aggregate function and not in GROUP BY.",
|
||||
throw Exception("Column " + backQuote(column_name.get(ast)) + " is not under aggregate function and not in GROUP BY",
|
||||
ErrorCodes::NOT_AN_AGGREGATE);
|
||||
|
||||
/// Special check for WITH statement alias. Add alias action to be able to use this alias.
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
#include <Parsers/IAST.h>
|
||||
#include <Interpreters/PreparedSets.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <Interpreters/SubqueryForSet.h>
|
||||
#include <Interpreters/InDepthNodeVisitor.h>
|
||||
|
||||
@ -13,6 +12,9 @@ namespace DB
|
||||
class Context;
|
||||
class ASTFunction;
|
||||
|
||||
class ExpressionActions;
|
||||
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
|
||||
|
||||
/// The case of an explicit enumeration of values.
|
||||
SetPtr makeExplicitSet(
|
||||
const ASTFunction * node, const Block & sample_block, bool create_ordered_set,
|
||||
|
169
dbms/src/Interpreters/ArrayJoinAction.cpp
Normal file
169
dbms/src/Interpreters/ArrayJoinAction.cpp
Normal file
@ -0,0 +1,169 @@
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/ArrayJoinAction.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
|
||||
extern const int TYPE_MISMATCH;
|
||||
}
|
||||
|
||||
ArrayJoinAction::ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, const Context & context)
|
||||
: columns(array_joined_columns_)
|
||||
, is_left(array_join_is_left)
|
||||
, is_unaligned(context.getSettingsRef().enable_unaligned_array_join)
|
||||
{
|
||||
if (columns.empty())
|
||||
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
if (is_unaligned)
|
||||
{
|
||||
function_length = FunctionFactory::instance().get("length", context);
|
||||
function_greatest = FunctionFactory::instance().get("greatest", context);
|
||||
function_arrayResize = FunctionFactory::instance().get("arrayResize", context);
|
||||
}
|
||||
else if (is_left)
|
||||
function_builder = FunctionFactory::instance().get("emptyArrayToSingle", context);
|
||||
}
|
||||
|
||||
|
||||
void ArrayJoinAction::prepare(Block & sample_block)
|
||||
{
|
||||
for (const auto & name : columns)
|
||||
{
|
||||
ColumnWithTypeAndName & current = sample_block.getByName(name);
|
||||
const DataTypeArray * array_type = typeid_cast<const DataTypeArray *>(&*current.type);
|
||||
if (!array_type)
|
||||
throw Exception("ARRAY JOIN requires array argument", ErrorCodes::TYPE_MISMATCH);
|
||||
current.type = array_type->getNestedType();
|
||||
current.column = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void ArrayJoinAction::execute(Block & block, bool dry_run)
|
||||
{
|
||||
if (columns.empty())
|
||||
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
ColumnPtr any_array_ptr = block.getByName(*columns.begin()).column->convertToFullColumnIfConst();
|
||||
const ColumnArray * any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
|
||||
if (!any_array)
|
||||
throw Exception("ARRAY JOIN of not array: " + *columns.begin(), ErrorCodes::TYPE_MISMATCH);
|
||||
|
||||
/// If LEFT ARRAY JOIN, then we create columns in which empty arrays are replaced by arrays with one element - the default value.
|
||||
std::map<String, ColumnPtr> non_empty_array_columns;
|
||||
|
||||
if (is_unaligned)
|
||||
{
|
||||
/// Resize all array joined columns to the longest one, (at least 1 if LEFT ARRAY JOIN), padded with default values.
|
||||
auto rows = block.rows();
|
||||
auto uint64 = std::make_shared<DataTypeUInt64>();
|
||||
ColumnWithTypeAndName column_of_max_length;
|
||||
if (is_left)
|
||||
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 1u), uint64, {});
|
||||
else
|
||||
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 0u), uint64, {});
|
||||
|
||||
for (const auto & name : columns)
|
||||
{
|
||||
auto & src_col = block.getByName(name);
|
||||
|
||||
Block tmp_block{src_col, {{}, uint64, {}}};
|
||||
function_length->build({src_col})->execute(tmp_block, {0}, 1, rows);
|
||||
|
||||
Block tmp_block2{
|
||||
column_of_max_length, tmp_block.safeGetByPosition(1), {{}, uint64, {}}};
|
||||
function_greatest->build({column_of_max_length, tmp_block.safeGetByPosition(1)})->execute(tmp_block2, {0, 1}, 2, rows);
|
||||
column_of_max_length = tmp_block2.safeGetByPosition(2);
|
||||
}
|
||||
|
||||
for (const auto & name : columns)
|
||||
{
|
||||
auto & src_col = block.getByName(name);
|
||||
|
||||
Block tmp_block{src_col, column_of_max_length, {{}, src_col.type, {}}};
|
||||
function_arrayResize->build({src_col, column_of_max_length})->execute(tmp_block, {0, 1}, 2, rows);
|
||||
src_col.column = tmp_block.safeGetByPosition(2).column;
|
||||
any_array_ptr = src_col.column->convertToFullColumnIfConst();
|
||||
}
|
||||
|
||||
any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
|
||||
}
|
||||
else if (is_left)
|
||||
{
|
||||
for (const auto & name : columns)
|
||||
{
|
||||
auto src_col = block.getByName(name);
|
||||
|
||||
Block tmp_block{src_col, {{}, src_col.type, {}}};
|
||||
|
||||
function_builder->build({src_col})->execute(tmp_block, {0}, 1, src_col.column->size(), dry_run);
|
||||
non_empty_array_columns[name] = tmp_block.safeGetByPosition(1).column;
|
||||
}
|
||||
|
||||
any_array_ptr = non_empty_array_columns.begin()->second->convertToFullColumnIfConst();
|
||||
any_array = &typeid_cast<const ColumnArray &>(*any_array_ptr);
|
||||
}
|
||||
|
||||
size_t num_columns = block.columns();
|
||||
for (size_t i = 0; i < num_columns; ++i)
|
||||
{
|
||||
ColumnWithTypeAndName & current = block.safeGetByPosition(i);
|
||||
|
||||
if (columns.count(current.name))
|
||||
{
|
||||
if (!typeid_cast<const DataTypeArray *>(&*current.type))
|
||||
throw Exception("ARRAY JOIN of not array: " + current.name, ErrorCodes::TYPE_MISMATCH);
|
||||
|
||||
ColumnPtr array_ptr = (is_left && !is_unaligned) ? non_empty_array_columns[current.name] : current.column;
|
||||
array_ptr = array_ptr->convertToFullColumnIfConst();
|
||||
|
||||
const ColumnArray & array = typeid_cast<const ColumnArray &>(*array_ptr);
|
||||
if (!is_unaligned && !array.hasEqualOffsets(typeid_cast<const ColumnArray &>(*any_array_ptr)))
|
||||
throw Exception("Sizes of ARRAY-JOIN-ed arrays do not match", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
|
||||
|
||||
current.column = typeid_cast<const ColumnArray &>(*array_ptr).getDataPtr();
|
||||
current.type = typeid_cast<const DataTypeArray &>(*current.type).getNestedType();
|
||||
}
|
||||
else
|
||||
{
|
||||
current.column = current.column->replicate(any_array->getOffsets());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ArrayJoinAction::finalize(NameSet & needed_columns, NameSet & unmodified_columns, NameSet & final_columns)
|
||||
{
|
||||
/// Do not ARRAY JOIN columns that are not used anymore.
|
||||
/// Usually, such columns are not used until ARRAY JOIN, and therefore are ejected further in this function.
|
||||
/// We will not remove all the columns so as not to lose the number of rows.
|
||||
for (auto it = columns.begin(); it != columns.end();)
|
||||
{
|
||||
bool need = needed_columns.count(*it);
|
||||
if (!need && columns.size() > 1)
|
||||
{
|
||||
columns.erase(it++);
|
||||
}
|
||||
else
|
||||
{
|
||||
needed_columns.insert(*it);
|
||||
unmodified_columns.erase(*it);
|
||||
|
||||
/// If no ARRAY JOIN results are used, forcibly leave an arbitrary column at the output,
|
||||
/// so you do not lose the number of rows.
|
||||
if (!need)
|
||||
final_columns.insert(*it);
|
||||
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
35
dbms/src/Interpreters/ArrayJoinAction.h
Normal file
35
dbms/src/Interpreters/ArrayJoinAction.h
Normal file
@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Names.h>
|
||||
#include <Core/Block.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class Context;
|
||||
|
||||
class IFunctionOverloadResolver;
|
||||
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
|
||||
|
||||
struct ArrayJoinAction
|
||||
{
|
||||
NameSet columns;
|
||||
bool is_left = false;
|
||||
bool is_unaligned = false;
|
||||
|
||||
/// For unaligned [LEFT] ARRAY JOIN
|
||||
FunctionOverloadResolverPtr function_length;
|
||||
FunctionOverloadResolverPtr function_greatest;
|
||||
FunctionOverloadResolverPtr function_arrayResize;
|
||||
|
||||
/// For LEFT ARRAY JOIN.
|
||||
FunctionOverloadResolverPtr function_builder;
|
||||
|
||||
ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, const Context & context);
|
||||
void prepare(Block & sample_block);
|
||||
void execute(Block & block, bool dry_run);
|
||||
void finalize(NameSet & needed_columns, NameSet & unmodified_columns, NameSet & final_columns);
|
||||
};
|
||||
|
||||
}
|
@ -57,6 +57,7 @@
|
||||
#include <Common/TraceCollector.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <Common/RemoteHostFilter.h>
|
||||
#include <ext/singleton.h>
|
||||
|
||||
namespace ProfileEvents
|
||||
{
|
||||
@ -168,7 +169,6 @@ struct ContextShared
|
||||
|
||||
RemoteHostFilter remote_host_filter; /// Allowed URL from config.xml
|
||||
|
||||
std::unique_ptr<TraceCollector> trace_collector; /// Thread collecting traces from threads executing queries
|
||||
/// Named sessions. The user could specify session identifier to reuse settings and temporary tables in subsequent requests.
|
||||
|
||||
class SessionKeyHash
|
||||
@ -299,13 +299,7 @@ struct ContextShared
|
||||
schedule_pool.reset();
|
||||
ddl_worker.reset();
|
||||
|
||||
/// Stop trace collector if any
|
||||
trace_collector.reset();
|
||||
}
|
||||
|
||||
bool hasTraceCollector()
|
||||
{
|
||||
return trace_collector != nullptr;
|
||||
ext::Singleton<TraceCollector>::reset();
|
||||
}
|
||||
|
||||
void initializeTraceCollector(std::shared_ptr<TraceLog> trace_log)
|
||||
@ -313,7 +307,7 @@ struct ContextShared
|
||||
if (trace_log == nullptr)
|
||||
return;
|
||||
|
||||
trace_collector = std::make_unique<TraceCollector>(trace_log);
|
||||
ext::Singleton<TraceCollector>()->setTraceLog(trace_log);
|
||||
}
|
||||
};
|
||||
|
||||
@ -650,6 +644,10 @@ void Context::checkAccess(const AccessFlags & access, const std::string_view & d
|
||||
void Context::checkAccess(const AccessRightsElement & access) const { return checkAccessImpl(access); }
|
||||
void Context::checkAccess(const AccessRightsElements & access) const { return checkAccessImpl(access); }
|
||||
|
||||
void Context::switchRowPolicy()
|
||||
{
|
||||
row_policy = getAccessControlManager().getRowPolicyContext(client_info.initial_user);
|
||||
}
|
||||
|
||||
void Context::setUsersConfig(const ConfigurationPtr & config)
|
||||
{
|
||||
@ -688,7 +686,7 @@ void Context::calculateAccessRights()
|
||||
{
|
||||
auto lock = getLock();
|
||||
if (user)
|
||||
std::atomic_store(&access_rights, getAccessControlManager().getAccessRightsContext(client_info, user->access, settings, current_database));
|
||||
std::atomic_store(&access_rights, getAccessControlManager().getAccessRightsContext(user, client_info, settings, current_database));
|
||||
}
|
||||
|
||||
void Context::setProfile(const String & profile)
|
||||
@ -701,9 +699,18 @@ void Context::setProfile(const String & profile)
|
||||
settings_constraints = std::move(new_constraints);
|
||||
}
|
||||
|
||||
std::shared_ptr<const User> Context::getUser(const String & user_name) const
|
||||
std::shared_ptr<const User> Context::getUser() const
|
||||
{
|
||||
return shared->access_control_manager.getUser(user_name);
|
||||
if (!user)
|
||||
throw Exception("No current user", ErrorCodes::LOGICAL_ERROR);
|
||||
return user;
|
||||
}
|
||||
|
||||
UUID Context::getUserID() const
|
||||
{
|
||||
if (!user)
|
||||
throw Exception("No current user", ErrorCodes::LOGICAL_ERROR);
|
||||
return user_id;
|
||||
}
|
||||
|
||||
void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key)
|
||||
@ -717,8 +724,9 @@ void Context::setUser(const String & name, const String & password, const Poco::
|
||||
if (!quota_key.empty())
|
||||
client_info.quota_key = quota_key;
|
||||
|
||||
user_id = shared->access_control_manager.getID<User>(name);
|
||||
user = shared->access_control_manager.authorizeAndGetUser(
|
||||
name,
|
||||
user_id,
|
||||
password,
|
||||
address.host(),
|
||||
[this](const UserPtr & changed_user)
|
||||
@ -1679,11 +1687,6 @@ void Context::initializeSystemLogs()
|
||||
shared->system_logs.emplace(*global_context, getConfigRef());
|
||||
}
|
||||
|
||||
bool Context::hasTraceCollector()
|
||||
{
|
||||
return shared->hasTraceCollector();
|
||||
}
|
||||
|
||||
void Context::initializeTraceCollector()
|
||||
{
|
||||
shared->initializeTraceCollector(getTraceLog());
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <Core/NamesAndTypes.h>
|
||||
#include <Core/Settings.h>
|
||||
#include <Core/Types.h>
|
||||
#include <Core/UUID.h>
|
||||
#include <DataStreams/IBlockStream_fwd.h>
|
||||
#include <Interpreters/ClientInfo.h>
|
||||
#include <Parsers/IAST_fwd.h>
|
||||
@ -161,6 +162,7 @@ private:
|
||||
InputBlocksReader input_blocks_reader;
|
||||
|
||||
std::shared_ptr<const User> user;
|
||||
UUID user_id;
|
||||
SubscriptionForUserChange subscription_for_user_change;
|
||||
std::shared_ptr<const AccessRightsContext> access_rights;
|
||||
std::shared_ptr<QuotaContext> quota; /// Current quota. By default - empty quota, that have no limits.
|
||||
@ -251,6 +253,10 @@ public:
|
||||
std::shared_ptr<QuotaContext> getQuota() const { return quota; }
|
||||
std::shared_ptr<RowPolicyContext> getRowPolicy() const { return row_policy; }
|
||||
|
||||
/// TODO: we need much better code for switching policies, quotas, access rights for initial user
|
||||
/// Switches row policy in case we have initial user in client info
|
||||
void switchRowPolicy();
|
||||
|
||||
/** Take the list of users, quotas and configuration profiles from this config.
|
||||
* The list of users is completely replaced.
|
||||
* The accumulated quota values are not reset if the quota is not deleted.
|
||||
@ -260,10 +266,8 @@ public:
|
||||
|
||||
/// Must be called before getClientInfo.
|
||||
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key);
|
||||
std::shared_ptr<const User> getUser() const { return user; }
|
||||
|
||||
/// Used by MySQL Secure Password Authentication plugin.
|
||||
std::shared_ptr<const User> getUser(const String & user_name) const;
|
||||
std::shared_ptr<const User> getUser() const;
|
||||
UUID getUserID() const;
|
||||
|
||||
/// We have to copy external tables inside executeQuery() to track limits. Therefore, set callback for it. Must set once.
|
||||
void setExternalTablesInitializer(ExternalTablesInitializer && initializer);
|
||||
|
@ -6,13 +6,11 @@
|
||||
#include <Interpreters/ExpressionJIT.h>
|
||||
#include <Interpreters/AnalyzedJoin.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/IFunction.h>
|
||||
#include <set>
|
||||
#include <optional>
|
||||
#include <Columns/ColumnSet.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
@ -33,20 +31,20 @@ namespace ErrorCodes
|
||||
extern const int UNKNOWN_IDENTIFIER;
|
||||
extern const int UNKNOWN_ACTION;
|
||||
extern const int NOT_FOUND_COLUMN_IN_BLOCK;
|
||||
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
|
||||
extern const int TOO_MANY_TEMPORARY_COLUMNS;
|
||||
extern const int TOO_MANY_TEMPORARY_NON_CONST_COLUMNS;
|
||||
extern const int TYPE_MISMATCH;
|
||||
}
|
||||
|
||||
/// Read comment near usage
|
||||
static constexpr auto DUMMY_COLUMN_NAME = "_dummy";
|
||||
|
||||
|
||||
Names ExpressionAction::getNeededColumns() const
|
||||
{
|
||||
Names res = argument_names;
|
||||
|
||||
res.insert(res.end(), array_joined_columns.begin(), array_joined_columns.end());
|
||||
if (array_join)
|
||||
res.insert(res.end(), array_join->columns.begin(), array_join->columns.end());
|
||||
|
||||
if (table_join)
|
||||
res.insert(res.end(), table_join->keyNamesLeft().begin(), table_join->keyNamesLeft().end());
|
||||
@ -143,23 +141,9 @@ ExpressionAction ExpressionAction::addAliases(const NamesWithAliases & aliased_c
|
||||
|
||||
ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_columns, bool array_join_is_left, const Context & context)
|
||||
{
|
||||
if (array_joined_columns.empty())
|
||||
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
|
||||
ExpressionAction a;
|
||||
a.type = ARRAY_JOIN;
|
||||
a.array_joined_columns = array_joined_columns;
|
||||
a.array_join_is_left = array_join_is_left;
|
||||
a.unaligned_array_join = context.getSettingsRef().enable_unaligned_array_join;
|
||||
|
||||
if (a.unaligned_array_join)
|
||||
{
|
||||
a.function_length = FunctionFactory::instance().get("length", context);
|
||||
a.function_greatest = FunctionFactory::instance().get("greatest", context);
|
||||
a.function_arrayResize = FunctionFactory::instance().get("arrayResize", context);
|
||||
}
|
||||
else if (array_join_is_left)
|
||||
a.function_builder = FunctionFactory::instance().get("emptyArrayToSingle", context);
|
||||
|
||||
a.array_join = std::make_shared<ArrayJoinAction>(array_joined_columns, array_join_is_left, context);
|
||||
return a;
|
||||
}
|
||||
|
||||
@ -172,7 +156,6 @@ ExpressionAction ExpressionAction::ordinaryJoin(std::shared_ptr<AnalyzedJoin> ta
|
||||
return a;
|
||||
}
|
||||
|
||||
|
||||
void ExpressionAction::prepare(Block & sample_block, const Settings & settings, NameSet & names_not_for_constant_folding)
|
||||
{
|
||||
// std::cerr << "preparing: " << toString() << std::endl;
|
||||
@ -256,16 +239,7 @@ void ExpressionAction::prepare(Block & sample_block, const Settings & settings,
|
||||
|
||||
case ARRAY_JOIN:
|
||||
{
|
||||
for (const auto & name : array_joined_columns)
|
||||
{
|
||||
ColumnWithTypeAndName & current = sample_block.getByName(name);
|
||||
const DataTypeArray * array_type = typeid_cast<const DataTypeArray *>(&*current.type);
|
||||
if (!array_type)
|
||||
throw Exception("ARRAY JOIN requires array argument", ErrorCodes::TYPE_MISMATCH);
|
||||
current.type = array_type->getNestedType();
|
||||
current.column = nullptr;
|
||||
}
|
||||
|
||||
array_join->prepare(sample_block);
|
||||
break;
|
||||
}
|
||||
|
||||
@ -383,95 +357,7 @@ void ExpressionAction::execute(Block & block, bool dry_run, ExtraBlockPtr & not_
|
||||
|
||||
case ARRAY_JOIN:
|
||||
{
|
||||
if (array_joined_columns.empty())
|
||||
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
ColumnPtr any_array_ptr = block.getByName(*array_joined_columns.begin()).column->convertToFullColumnIfConst();
|
||||
const ColumnArray * any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
|
||||
if (!any_array)
|
||||
throw Exception("ARRAY JOIN of not array: " + *array_joined_columns.begin(), ErrorCodes::TYPE_MISMATCH);
|
||||
|
||||
/// If LEFT ARRAY JOIN, then we create columns in which empty arrays are replaced by arrays with one element - the default value.
|
||||
std::map<String, ColumnPtr> non_empty_array_columns;
|
||||
|
||||
if (unaligned_array_join)
|
||||
{
|
||||
/// Resize all array joined columns to the longest one, (at least 1 if LEFT ARRAY JOIN), padded with default values.
|
||||
auto rows = block.rows();
|
||||
auto uint64 = std::make_shared<DataTypeUInt64>();
|
||||
ColumnWithTypeAndName column_of_max_length;
|
||||
if (array_join_is_left)
|
||||
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 1u), uint64, {});
|
||||
else
|
||||
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 0u), uint64, {});
|
||||
|
||||
for (const auto & name : array_joined_columns)
|
||||
{
|
||||
auto & src_col = block.getByName(name);
|
||||
|
||||
Block tmp_block{src_col, {{}, uint64, {}}};
|
||||
function_length->build({src_col})->execute(tmp_block, {0}, 1, rows);
|
||||
|
||||
Block tmp_block2{
|
||||
column_of_max_length, tmp_block.safeGetByPosition(1), {{}, uint64, {}}};
|
||||
function_greatest->build({column_of_max_length, tmp_block.safeGetByPosition(1)})->execute(tmp_block2, {0, 1}, 2, rows);
|
||||
column_of_max_length = tmp_block2.safeGetByPosition(2);
|
||||
}
|
||||
|
||||
for (const auto & name : array_joined_columns)
|
||||
{
|
||||
auto & src_col = block.getByName(name);
|
||||
|
||||
Block tmp_block{src_col, column_of_max_length, {{}, src_col.type, {}}};
|
||||
function_arrayResize->build({src_col, column_of_max_length})->execute(tmp_block, {0, 1}, 2, rows);
|
||||
src_col.column = tmp_block.safeGetByPosition(2).column;
|
||||
any_array_ptr = src_col.column->convertToFullColumnIfConst();
|
||||
}
|
||||
|
||||
any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
|
||||
}
|
||||
else if (array_join_is_left)
|
||||
{
|
||||
for (const auto & name : array_joined_columns)
|
||||
{
|
||||
auto src_col = block.getByName(name);
|
||||
|
||||
Block tmp_block{src_col, {{}, src_col.type, {}}};
|
||||
|
||||
function_builder->build({src_col})->execute(tmp_block, {0}, 1, src_col.column->size(), dry_run);
|
||||
non_empty_array_columns[name] = tmp_block.safeGetByPosition(1).column;
|
||||
}
|
||||
|
||||
any_array_ptr = non_empty_array_columns.begin()->second->convertToFullColumnIfConst();
|
||||
any_array = &typeid_cast<const ColumnArray &>(*any_array_ptr);
|
||||
}
|
||||
|
||||
size_t columns = block.columns();
|
||||
for (size_t i = 0; i < columns; ++i)
|
||||
{
|
||||
ColumnWithTypeAndName & current = block.safeGetByPosition(i);
|
||||
|
||||
if (array_joined_columns.count(current.name))
|
||||
{
|
||||
if (!typeid_cast<const DataTypeArray *>(&*current.type))
|
||||
throw Exception("ARRAY JOIN of not array: " + current.name, ErrorCodes::TYPE_MISMATCH);
|
||||
|
||||
ColumnPtr array_ptr = (array_join_is_left && !unaligned_array_join) ? non_empty_array_columns[current.name] : current.column;
|
||||
array_ptr = array_ptr->convertToFullColumnIfConst();
|
||||
|
||||
const ColumnArray & array = typeid_cast<const ColumnArray &>(*array_ptr);
|
||||
if (!unaligned_array_join && !array.hasEqualOffsets(typeid_cast<const ColumnArray &>(*any_array_ptr)))
|
||||
throw Exception("Sizes of ARRAY-JOIN-ed arrays do not match", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
|
||||
|
||||
current.column = typeid_cast<const ColumnArray &>(*array_ptr).getDataPtr();
|
||||
current.type = typeid_cast<const DataTypeArray &>(*current.type).getNestedType();
|
||||
}
|
||||
else
|
||||
{
|
||||
current.column = current.column->replicate(any_array->getOffsets());
|
||||
}
|
||||
}
|
||||
|
||||
array_join->execute(block, dry_run);
|
||||
break;
|
||||
}
|
||||
|
||||
@ -539,7 +425,6 @@ void ExpressionAction::execute(Block & block, bool dry_run, ExtraBlockPtr & not_
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void ExpressionAction::executeOnTotals(Block & block) const
|
||||
{
|
||||
if (type != JOIN)
|
||||
@ -584,10 +469,10 @@ std::string ExpressionAction::toString() const
|
||||
break;
|
||||
|
||||
case ARRAY_JOIN:
|
||||
ss << (array_join_is_left ? "LEFT " : "") << "ARRAY JOIN ";
|
||||
for (NameSet::const_iterator it = array_joined_columns.begin(); it != array_joined_columns.end(); ++it)
|
||||
ss << (array_join->is_left ? "LEFT " : "") << "ARRAY JOIN ";
|
||||
for (NameSet::const_iterator it = array_join->columns.begin(); it != array_join->columns.end(); ++it)
|
||||
{
|
||||
if (it != array_joined_columns.begin())
|
||||
if (it != array_join->columns.begin())
|
||||
ss << ", ";
|
||||
ss << *it;
|
||||
}
|
||||
@ -675,7 +560,9 @@ void ExpressionActions::addImpl(ExpressionAction action, Names & new_names)
|
||||
{
|
||||
if (action.result_name != "")
|
||||
new_names.push_back(action.result_name);
|
||||
new_names.insert(new_names.end(), action.array_joined_columns.begin(), action.array_joined_columns.end());
|
||||
|
||||
if (action.array_join)
|
||||
new_names.insert(new_names.end(), action.array_join->columns.begin(), action.array_join->columns.end());
|
||||
|
||||
/// Compiled functions are custom functions and they don't need building
|
||||
if (action.type == ExpressionAction::APPLY_FUNCTION && !action.is_function_compiled)
|
||||
@ -713,7 +600,7 @@ void ExpressionActions::prependArrayJoin(const ExpressionAction & action, const
|
||||
if (action.type != ExpressionAction::ARRAY_JOIN)
|
||||
throw Exception("ARRAY_JOIN action expected", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
NameSet array_join_set(action.array_joined_columns.begin(), action.array_joined_columns.end());
|
||||
NameSet array_join_set(action.array_join->columns.begin(), action.array_join->columns.end());
|
||||
for (auto & it : input_columns)
|
||||
{
|
||||
if (array_join_set.count(it.name))
|
||||
@ -738,12 +625,12 @@ bool ExpressionActions::popUnusedArrayJoin(const Names & required_columns, Expre
|
||||
if (actions.empty() || actions.back().type != ExpressionAction::ARRAY_JOIN)
|
||||
return false;
|
||||
NameSet required_set(required_columns.begin(), required_columns.end());
|
||||
for (const std::string & name : actions.back().array_joined_columns)
|
||||
for (const std::string & name : actions.back().array_join->columns)
|
||||
{
|
||||
if (required_set.count(name))
|
||||
return false;
|
||||
}
|
||||
for (const std::string & name : actions.back().array_joined_columns)
|
||||
for (const std::string & name : actions.back().array_join->columns)
|
||||
{
|
||||
DataTypePtr & type = sample_block.getByName(name).type;
|
||||
type = std::make_shared<DataTypeArray>(type);
|
||||
@ -884,29 +771,7 @@ void ExpressionActions::finalize(const Names & output_columns)
|
||||
}
|
||||
else if (action.type == ExpressionAction::ARRAY_JOIN)
|
||||
{
|
||||
/// Do not ARRAY JOIN columns that are not used anymore.
|
||||
/// Usually, such columns are not used until ARRAY JOIN, and therefore are ejected further in this function.
|
||||
/// We will not remove all the columns so as not to lose the number of rows.
|
||||
for (auto it = action.array_joined_columns.begin(); it != action.array_joined_columns.end();)
|
||||
{
|
||||
bool need = needed_columns.count(*it);
|
||||
if (!need && action.array_joined_columns.size() > 1)
|
||||
{
|
||||
action.array_joined_columns.erase(it++);
|
||||
}
|
||||
else
|
||||
{
|
||||
needed_columns.insert(*it);
|
||||
unmodified_columns.erase(*it);
|
||||
|
||||
/// If no ARRAY JOIN results are used, forcibly leave an arbitrary column at the output,
|
||||
/// so you do not lose the number of rows.
|
||||
if (!need)
|
||||
final_columns.insert(*it);
|
||||
|
||||
++it;
|
||||
}
|
||||
}
|
||||
action.array_join->finalize(needed_columns, unmodified_columns, final_columns);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -1143,7 +1008,8 @@ void ExpressionActions::optimizeArrayJoin()
|
||||
|
||||
if (actions[i].result_name != "")
|
||||
array_joined_columns.insert(actions[i].result_name);
|
||||
array_joined_columns.insert(actions[i].array_joined_columns.begin(), actions[i].array_joined_columns.end());
|
||||
if (actions[i].array_join)
|
||||
array_joined_columns.insert(actions[i].array_join->columns.begin(), actions[i].array_join->columns.end());
|
||||
|
||||
array_join_dependencies.insert(needed.begin(), needed.end());
|
||||
}
|
||||
@ -1274,8 +1140,8 @@ UInt128 ExpressionAction::ActionHash::operator()(const ExpressionAction & action
|
||||
hash.update(arg_name);
|
||||
break;
|
||||
case ARRAY_JOIN:
|
||||
hash.update(action.array_join_is_left);
|
||||
for (const auto & col : action.array_joined_columns)
|
||||
hash.update(action.array_join->is_left);
|
||||
for (const auto & col : action.array_join->columns)
|
||||
hash.update(col);
|
||||
break;
|
||||
case JOIN:
|
||||
@ -1332,11 +1198,15 @@ bool ExpressionAction::operator==(const ExpressionAction & other) const
|
||||
return false;
|
||||
}
|
||||
|
||||
bool same_array_join = !array_join && !other.array_join;
|
||||
if (array_join && other.array_join)
|
||||
same_array_join = (array_join->columns == other.array_join->columns) &&
|
||||
(array_join->is_left == other.array_join->is_left);
|
||||
|
||||
return source_name == other.source_name
|
||||
&& result_name == other.result_name
|
||||
&& argument_names == other.argument_names
|
||||
&& array_joined_columns == other.array_joined_columns
|
||||
&& array_join_is_left == other.array_join_is_left
|
||||
&& same_array_join
|
||||
&& AnalyzedJoin::sameJoin(table_join.get(), other.table_join.get())
|
||||
&& projection == other.projection
|
||||
&& is_function_compiled == other.is_function_compiled;
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <Parsers/ASTTablesInSelectQuery.h>
|
||||
#include <Interpreters/ArrayJoinAction.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -81,15 +82,10 @@ public:
|
||||
/// For ADD_COLUMN.
|
||||
ColumnPtr added_column;
|
||||
|
||||
/// For APPLY_FUNCTION and LEFT ARRAY JOIN.
|
||||
/// For APPLY_FUNCTION.
|
||||
/// OverloadResolver is used before action was added to ExpressionActions (when we don't know types of arguments).
|
||||
FunctionOverloadResolverPtr function_builder;
|
||||
|
||||
/// For unaligned [LEFT] ARRAY JOIN
|
||||
FunctionOverloadResolverPtr function_length;
|
||||
FunctionOverloadResolverPtr function_greatest;
|
||||
FunctionOverloadResolverPtr function_arrayResize;
|
||||
|
||||
/// Can be used after action was added to ExpressionActions if we want to get function signature or properties like monotonicity.
|
||||
FunctionBasePtr function_base;
|
||||
/// Prepared function which is used in function execution.
|
||||
@ -97,10 +93,8 @@ public:
|
||||
Names argument_names;
|
||||
bool is_function_compiled = false;
|
||||
|
||||
/// For ARRAY_JOIN
|
||||
NameSet array_joined_columns;
|
||||
bool array_join_is_left = false;
|
||||
bool unaligned_array_join = false;
|
||||
/// For ARRAY JOIN
|
||||
std::shared_ptr<ArrayJoinAction> array_join;
|
||||
|
||||
/// For JOIN
|
||||
std::shared_ptr<const AnalyzedJoin> table_join;
|
||||
|
@ -70,9 +70,51 @@ using LogAST = DebugASTLog<false>; /// set to true to enable logs
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int UNKNOWN_IDENTIFIER;
|
||||
extern const int ILLEGAL_PREWHERE;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/// Check if there is an ignore function. It's used for disabling constant folding in query
|
||||
/// predicates because some performance tests use ignore function as a non-optimize guard.
|
||||
bool allowEarlyConstantFolding(const ExpressionActions & actions, const Settings & settings)
|
||||
{
|
||||
if (!settings.enable_early_constant_folding)
|
||||
return false;
|
||||
|
||||
for (auto & action : actions.getActions())
|
||||
{
|
||||
if (action.type == action.APPLY_FUNCTION && action.function_base)
|
||||
{
|
||||
auto name = action.function_base->getName();
|
||||
if (name == "ignore")
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
bool sanitizeBlock(Block & block)
|
||||
{
|
||||
for (auto & col : block)
|
||||
{
|
||||
if (!col.column)
|
||||
{
|
||||
if (isNotCreatable(col.type->getTypeId()))
|
||||
return false;
|
||||
col.column = col.type->createColumn();
|
||||
}
|
||||
else if (isColumnConst(*col.column) && !col.column->empty())
|
||||
col.column = col.column->cloneEmpty();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
ExpressionAnalyzer::ExpressionAnalyzer(
|
||||
const ASTPtr & query_,
|
||||
const SyntaxAnalyzerResultPtr & syntax_analyzer_result_,
|
||||
@ -733,7 +775,8 @@ void SelectQueryExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain,
|
||||
step.required_output.push_back(child->getColumnName());
|
||||
}
|
||||
|
||||
bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order)
|
||||
bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order,
|
||||
ManyExpressionActions & order_by_elements_actions)
|
||||
{
|
||||
const auto * select_query = getSelectQuery();
|
||||
|
||||
@ -884,12 +927,239 @@ ExpressionActionsPtr ExpressionAnalyzer::getConstActions()
|
||||
return actions;
|
||||
}
|
||||
|
||||
void SelectQueryExpressionAnalyzer::getAggregateInfo(Names & key_names, AggregateDescriptions & aggregates) const
|
||||
ExpressionActionsPtr SelectQueryExpressionAnalyzer::simpleSelectActions()
|
||||
{
|
||||
for (const auto & name_and_type : aggregation_keys)
|
||||
key_names.emplace_back(name_and_type.name);
|
||||
ExpressionActionsChain new_chain(context);
|
||||
appendSelect(new_chain, false);
|
||||
return new_chain.getLastActions();
|
||||
}
|
||||
|
||||
aggregates = aggregate_descriptions;
|
||||
ExpressionAnalysisResult::ExpressionAnalysisResult(
|
||||
SelectQueryExpressionAnalyzer & query_analyzer,
|
||||
bool first_stage_,
|
||||
bool second_stage_,
|
||||
bool only_types,
|
||||
const FilterInfoPtr & filter_info_,
|
||||
const Block & source_header)
|
||||
: first_stage(first_stage_)
|
||||
, second_stage(second_stage_)
|
||||
, need_aggregate(query_analyzer.hasAggregation())
|
||||
{
|
||||
/// first_stage: Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
|
||||
/// second_stage: Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
|
||||
|
||||
/** First we compose a chain of actions and remember the necessary steps from it.
|
||||
* Regardless of from_stage and to_stage, we will compose a complete sequence of actions to perform optimization and
|
||||
* throw out unnecessary columns based on the entire query. In unnecessary parts of the query, we will not execute subqueries.
|
||||
*/
|
||||
|
||||
const ASTSelectQuery & query = *query_analyzer.getSelectQuery();
|
||||
const Context & context = query_analyzer.context;
|
||||
const Settings & settings = context.getSettingsRef();
|
||||
const StoragePtr & storage = query_analyzer.storage();
|
||||
|
||||
bool finalized = false;
|
||||
size_t where_step_num = 0;
|
||||
|
||||
auto finalizeChain = [&](ExpressionActionsChain & chain)
|
||||
{
|
||||
if (!finalized)
|
||||
{
|
||||
chain.finalize();
|
||||
finalize(chain, context, where_step_num);
|
||||
chain.clear();
|
||||
}
|
||||
finalized = true;
|
||||
};
|
||||
|
||||
{
|
||||
ExpressionActionsChain chain(context);
|
||||
Names additional_required_columns_after_prewhere;
|
||||
|
||||
if (storage && (query.sample_size() || settings.parallel_replicas_count > 1))
|
||||
{
|
||||
Names columns_for_sampling = storage->getColumnsRequiredForSampling();
|
||||
additional_required_columns_after_prewhere.insert(additional_required_columns_after_prewhere.end(),
|
||||
columns_for_sampling.begin(), columns_for_sampling.end());
|
||||
}
|
||||
|
||||
if (storage && query.final())
|
||||
{
|
||||
Names columns_for_final = storage->getColumnsRequiredForFinal();
|
||||
additional_required_columns_after_prewhere.insert(additional_required_columns_after_prewhere.end(),
|
||||
columns_for_final.begin(), columns_for_final.end());
|
||||
}
|
||||
|
||||
if (storage && filter_info_)
|
||||
{
|
||||
filter_info = filter_info_;
|
||||
query_analyzer.appendPreliminaryFilter(chain, filter_info->actions, filter_info->column_name);
|
||||
}
|
||||
|
||||
if (query_analyzer.appendPrewhere(chain, !first_stage, additional_required_columns_after_prewhere))
|
||||
{
|
||||
prewhere_info = std::make_shared<PrewhereInfo>(
|
||||
chain.steps.front().actions, query.prewhere()->getColumnName());
|
||||
|
||||
if (allowEarlyConstantFolding(*prewhere_info->prewhere_actions, settings))
|
||||
{
|
||||
Block before_prewhere_sample = source_header;
|
||||
if (sanitizeBlock(before_prewhere_sample))
|
||||
{
|
||||
prewhere_info->prewhere_actions->execute(before_prewhere_sample);
|
||||
auto & column_elem = before_prewhere_sample.getByName(query.prewhere()->getColumnName());
|
||||
/// If the filter column is a constant, record it.
|
||||
if (column_elem.column)
|
||||
prewhere_constant_filter_description = ConstantFilterDescription(*column_elem.column);
|
||||
}
|
||||
}
|
||||
chain.addStep();
|
||||
}
|
||||
|
||||
query_analyzer.appendArrayJoin(chain, only_types || !first_stage);
|
||||
|
||||
if (query_analyzer.appendJoin(chain, only_types || !first_stage))
|
||||
{
|
||||
before_join = chain.getLastActions();
|
||||
if (!hasJoin())
|
||||
throw Exception("No expected JOIN", ErrorCodes::LOGICAL_ERROR);
|
||||
chain.addStep();
|
||||
}
|
||||
|
||||
if (query_analyzer.appendWhere(chain, only_types || !first_stage))
|
||||
{
|
||||
where_step_num = chain.steps.size() - 1;
|
||||
before_where = chain.getLastActions();
|
||||
if (allowEarlyConstantFolding(*before_where, settings))
|
||||
{
|
||||
Block before_where_sample;
|
||||
if (chain.steps.size() > 1)
|
||||
before_where_sample = chain.steps[chain.steps.size() - 2].actions->getSampleBlock();
|
||||
else
|
||||
before_where_sample = source_header;
|
||||
if (sanitizeBlock(before_where_sample))
|
||||
{
|
||||
before_where->execute(before_where_sample);
|
||||
auto & column_elem = before_where_sample.getByName(query.where()->getColumnName());
|
||||
/// If the filter column is a constant, record it.
|
||||
if (column_elem.column)
|
||||
where_constant_filter_description = ConstantFilterDescription(*column_elem.column);
|
||||
}
|
||||
}
|
||||
chain.addStep();
|
||||
}
|
||||
|
||||
if (need_aggregate)
|
||||
{
|
||||
query_analyzer.appendGroupBy(chain, only_types || !first_stage);
|
||||
query_analyzer.appendAggregateFunctionsArguments(chain, only_types || !first_stage);
|
||||
before_aggregation = chain.getLastActions();
|
||||
|
||||
finalizeChain(chain);
|
||||
|
||||
if (query_analyzer.appendHaving(chain, only_types || !second_stage))
|
||||
{
|
||||
before_having = chain.getLastActions();
|
||||
chain.addStep();
|
||||
}
|
||||
}
|
||||
|
||||
bool has_stream_with_non_joned_rows = (before_join && before_join->getTableJoinAlgo()->hasStreamWithNonJoinedRows());
|
||||
optimize_read_in_order =
|
||||
settings.optimize_read_in_order
|
||||
&& storage && query.orderBy()
|
||||
&& !query_analyzer.hasAggregation()
|
||||
&& !query.final()
|
||||
&& !has_stream_with_non_joned_rows;
|
||||
|
||||
/// If there is aggregation, we execute expressions in SELECT and ORDER BY on the initiating server, otherwise on the source servers.
|
||||
query_analyzer.appendSelect(chain, only_types || (need_aggregate ? !second_stage : !first_stage));
|
||||
selected_columns = chain.getLastStep().required_output;
|
||||
has_order_by = query_analyzer.appendOrderBy(chain, only_types || (need_aggregate ? !second_stage : !first_stage),
|
||||
optimize_read_in_order, order_by_elements_actions);
|
||||
before_order_and_select = chain.getLastActions();
|
||||
chain.addStep();
|
||||
|
||||
if (query_analyzer.appendLimitBy(chain, only_types || !second_stage))
|
||||
{
|
||||
before_limit_by = chain.getLastActions();
|
||||
chain.addStep();
|
||||
}
|
||||
|
||||
query_analyzer.appendProjectResult(chain);
|
||||
final_projection = chain.getLastActions();
|
||||
|
||||
finalizeChain(chain);
|
||||
}
|
||||
|
||||
/// Before executing WHERE and HAVING, remove the extra columns from the block (mostly the aggregation keys).
|
||||
removeExtraColumns();
|
||||
|
||||
checkActions();
|
||||
}
|
||||
|
||||
void ExpressionAnalysisResult::finalize(const ExpressionActionsChain & chain, const Context & context_, size_t where_step_num)
|
||||
{
|
||||
if (hasPrewhere())
|
||||
{
|
||||
const ExpressionActionsChain::Step & step = chain.steps.at(0);
|
||||
prewhere_info->remove_prewhere_column = step.can_remove_required_output.at(0);
|
||||
|
||||
Names columns_to_remove;
|
||||
for (size_t i = 1; i < step.required_output.size(); ++i)
|
||||
{
|
||||
if (step.can_remove_required_output[i])
|
||||
columns_to_remove.push_back(step.required_output[i]);
|
||||
}
|
||||
|
||||
if (!columns_to_remove.empty())
|
||||
{
|
||||
auto columns = prewhere_info->prewhere_actions->getSampleBlock().getNamesAndTypesList();
|
||||
ExpressionActionsPtr actions = std::make_shared<ExpressionActions>(columns, context_);
|
||||
for (const auto & column : columns_to_remove)
|
||||
actions->add(ExpressionAction::removeColumn(column));
|
||||
|
||||
prewhere_info->remove_columns_actions = std::move(actions);
|
||||
}
|
||||
|
||||
columns_to_remove_after_prewhere = std::move(columns_to_remove);
|
||||
}
|
||||
else if (hasFilter())
|
||||
{
|
||||
/// Can't have prewhere and filter set simultaneously
|
||||
filter_info->do_remove_column = chain.steps.at(0).can_remove_required_output.at(0);
|
||||
}
|
||||
if (hasWhere())
|
||||
remove_where_filter = chain.steps.at(where_step_num).can_remove_required_output.at(0);
|
||||
}
|
||||
|
||||
void ExpressionAnalysisResult::removeExtraColumns()
|
||||
{
|
||||
if (hasFilter())
|
||||
filter_info->actions->prependProjectInput();
|
||||
if (hasWhere())
|
||||
before_where->prependProjectInput();
|
||||
if (hasHaving())
|
||||
before_having->prependProjectInput();
|
||||
}
|
||||
|
||||
void ExpressionAnalysisResult::checkActions()
|
||||
{
|
||||
/// Check that PREWHERE doesn't contain unusual actions. Unusual actions are that can change number of rows.
|
||||
if (hasPrewhere())
|
||||
{
|
||||
auto check_actions = [](const ExpressionActionsPtr & actions)
|
||||
{
|
||||
if (actions)
|
||||
for (const auto & action : actions->getActions())
|
||||
if (action.type == ExpressionAction::Type::JOIN || action.type == ExpressionAction::Type::ARRAY_JOIN)
|
||||
throw Exception("PREWHERE cannot contain ARRAY JOIN or JOIN action", ErrorCodes::ILLEGAL_PREWHERE);
|
||||
};
|
||||
|
||||
check_actions(prewhere_info->prewhere_actions);
|
||||
check_actions(prewhere_info->alias_actions);
|
||||
check_actions(prewhere_info->remove_columns_actions);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -2,11 +2,13 @@
|
||||
|
||||
#include <Core/Settings.h>
|
||||
#include <DataStreams/IBlockStream_fwd.h>
|
||||
#include <Columns/FilterDescription.h>
|
||||
#include <Interpreters/AggregateDescription.h>
|
||||
#include <Interpreters/SyntaxAnalyzer.h>
|
||||
#include <Interpreters/SubqueryForSet.h>
|
||||
#include <Parsers/IAST_fwd.h>
|
||||
#include <Storages/IStorage_fwd.h>
|
||||
#include <Storages/SelectQueryInfo.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -29,6 +31,9 @@ class ASTExpressionList;
|
||||
class ASTSelectQuery;
|
||||
struct ASTTablesInSelectQueryElement;
|
||||
|
||||
/// Create columns in block or return false if not possible
|
||||
bool sanitizeBlock(Block & block);
|
||||
|
||||
/// ExpressionAnalyzer sources, intermediates and results. It splits data and logic, allows to test them separately.
|
||||
struct ExpressionAnalyzerData
|
||||
{
|
||||
@ -47,9 +52,6 @@ struct ExpressionAnalyzerData
|
||||
|
||||
/// All new temporary tables obtained by performing the GLOBAL IN/JOIN subqueries.
|
||||
Tables external_tables;
|
||||
|
||||
/// Actions by every element of ORDER BY
|
||||
ManyExpressionActions order_by_elements_actions;
|
||||
};
|
||||
|
||||
|
||||
@ -156,10 +158,71 @@ protected:
|
||||
bool isRemoteStorage() const;
|
||||
};
|
||||
|
||||
class SelectQueryExpressionAnalyzer;
|
||||
|
||||
/// Result of SelectQueryExpressionAnalyzer: expressions for InterpreterSelectQuery
|
||||
struct ExpressionAnalysisResult
|
||||
{
|
||||
/// Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
|
||||
bool first_stage = false;
|
||||
/// Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
|
||||
bool second_stage = false;
|
||||
|
||||
bool need_aggregate = false;
|
||||
bool has_order_by = false;
|
||||
|
||||
bool remove_where_filter = false;
|
||||
bool optimize_read_in_order = false;
|
||||
|
||||
ExpressionActionsPtr before_join; /// including JOIN
|
||||
ExpressionActionsPtr before_where;
|
||||
ExpressionActionsPtr before_aggregation;
|
||||
ExpressionActionsPtr before_having;
|
||||
ExpressionActionsPtr before_order_and_select;
|
||||
ExpressionActionsPtr before_limit_by;
|
||||
ExpressionActionsPtr final_projection;
|
||||
|
||||
/// Columns from the SELECT list, before renaming them to aliases.
|
||||
Names selected_columns;
|
||||
|
||||
/// Columns will be removed after prewhere actions execution.
|
||||
Names columns_to_remove_after_prewhere;
|
||||
|
||||
PrewhereInfoPtr prewhere_info;
|
||||
FilterInfoPtr filter_info;
|
||||
ConstantFilterDescription prewhere_constant_filter_description;
|
||||
ConstantFilterDescription where_constant_filter_description;
|
||||
/// Actions by every element of ORDER BY
|
||||
ManyExpressionActions order_by_elements_actions;
|
||||
|
||||
ExpressionAnalysisResult() = default;
|
||||
|
||||
ExpressionAnalysisResult(
|
||||
SelectQueryExpressionAnalyzer & query_analyzer,
|
||||
bool first_stage,
|
||||
bool second_stage,
|
||||
bool only_types,
|
||||
const FilterInfoPtr & filter_info,
|
||||
const Block & source_header);
|
||||
|
||||
bool hasFilter() const { return filter_info.get(); }
|
||||
bool hasJoin() const { return before_join.get(); }
|
||||
bool hasPrewhere() const { return prewhere_info.get(); }
|
||||
bool hasWhere() const { return before_where.get(); }
|
||||
bool hasHaving() const { return before_having.get(); }
|
||||
bool hasLimitBy() const { return before_limit_by.get(); }
|
||||
|
||||
void removeExtraColumns();
|
||||
void checkActions();
|
||||
void finalize(const ExpressionActionsChain & chain, const Context & context, size_t where_step_num);
|
||||
};
|
||||
|
||||
/// SelectQuery specific ExpressionAnalyzer part.
|
||||
class SelectQueryExpressionAnalyzer : public ExpressionAnalyzer
|
||||
{
|
||||
public:
|
||||
friend struct ExpressionAnalysisResult;
|
||||
|
||||
SelectQueryExpressionAnalyzer(
|
||||
const ASTPtr & query_,
|
||||
const SyntaxAnalyzerResultPtr & syntax_analyzer_result_,
|
||||
@ -175,16 +238,46 @@ public:
|
||||
bool hasAggregation() const { return has_aggregation; }
|
||||
bool hasGlobalSubqueries() { return has_global_subqueries; }
|
||||
|
||||
/// Get a list of aggregation keys and descriptions of aggregate functions if the query contains GROUP BY.
|
||||
void getAggregateInfo(Names & key_names, AggregateDescriptions & aggregates) const;
|
||||
const NamesAndTypesList & aggregationKeys() const { return aggregation_keys; }
|
||||
const AggregateDescriptions & aggregates() const { return aggregate_descriptions; }
|
||||
|
||||
/// Create Set-s that we make from IN section to use index on them.
|
||||
void makeSetsForIndex(const ASTPtr & node);
|
||||
const PreparedSets & getPreparedSets() const { return prepared_sets; }
|
||||
|
||||
const ManyExpressionActions & getOrderByActions() const { return order_by_elements_actions; }
|
||||
|
||||
/// Tables that will need to be sent to remote servers for distributed query processing.
|
||||
const Tables & getExternalTables() const { return external_tables; }
|
||||
|
||||
ExpressionActionsPtr simpleSelectActions();
|
||||
|
||||
/// These appends are public only for tests
|
||||
void appendSelect(ExpressionActionsChain & chain, bool only_types);
|
||||
/// Deletes all columns except mentioned by SELECT, arranges the remaining columns and renames them to aliases.
|
||||
void appendProjectResult(ExpressionActionsChain & chain) const;
|
||||
|
||||
private:
|
||||
/// If non-empty, ignore all expressions not from this list.
|
||||
NameSet required_result_columns;
|
||||
|
||||
/**
|
||||
* Create Set from a subquery or a table expression in the query. The created set is suitable for using the index.
|
||||
* The set will not be created if its size hits the limit.
|
||||
*/
|
||||
void tryMakeSetForIndexFromSubquery(const ASTPtr & subquery_or_table_name);
|
||||
|
||||
/**
|
||||
* Checks if subquery is not a plain StorageSet.
|
||||
* Because while making set we will read data from StorageSet which is not allowed.
|
||||
* Returns valid SetPtr from StorageSet if the latter is used after IN or nullptr otherwise.
|
||||
*/
|
||||
SetPtr isPlainStorageSetInSubquery(const ASTPtr & subquery_of_table_name);
|
||||
|
||||
JoinPtr makeTableJoin(const ASTTablesInSelectQueryElement & join_element);
|
||||
void makeSubqueryForJoin(const ASTTablesInSelectQueryElement & join_element, NamesWithAliases && required_columns_with_aliases,
|
||||
SubqueryForSet & subquery_for_set) const;
|
||||
|
||||
const ASTSelectQuery * getAggregatingQuery() const;
|
||||
|
||||
/** These methods allow you to build a chain of transformations over a block, that receives values in the desired sections of the query.
|
||||
*
|
||||
* Example usage:
|
||||
@ -213,37 +306,10 @@ public:
|
||||
|
||||
/// After aggregation:
|
||||
bool appendHaving(ExpressionActionsChain & chain, bool only_types);
|
||||
void appendSelect(ExpressionActionsChain & chain, bool only_types);
|
||||
bool appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order);
|
||||
/// appendSelect
|
||||
bool appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order, ManyExpressionActions &);
|
||||
bool appendLimitBy(ExpressionActionsChain & chain, bool only_types);
|
||||
/// Deletes all columns except mentioned by SELECT, arranges the remaining columns and renames them to aliases.
|
||||
void appendProjectResult(ExpressionActionsChain & chain) const;
|
||||
|
||||
/// Create Set-s that we can from IN section to use the index on them.
|
||||
void makeSetsForIndex(const ASTPtr & node);
|
||||
|
||||
private:
|
||||
/// If non-empty, ignore all expressions not from this list.
|
||||
NameSet required_result_columns;
|
||||
|
||||
/**
|
||||
* Create Set from a subquery or a table expression in the query. The created set is suitable for using the index.
|
||||
* The set will not be created if its size hits the limit.
|
||||
*/
|
||||
void tryMakeSetForIndexFromSubquery(const ASTPtr & subquery_or_table_name);
|
||||
|
||||
/**
|
||||
* Checks if subquery is not a plain StorageSet.
|
||||
* Because while making set we will read data from StorageSet which is not allowed.
|
||||
* Returns valid SetPtr from StorageSet if the latter is used after IN or nullptr otherwise.
|
||||
*/
|
||||
SetPtr isPlainStorageSetInSubquery(const ASTPtr & subquery_of_table_name);
|
||||
|
||||
JoinPtr makeTableJoin(const ASTTablesInSelectQueryElement & join_element);
|
||||
void makeSubqueryForJoin(const ASTTablesInSelectQueryElement & join_element, NamesWithAliases && required_columns_with_aliases,
|
||||
SubqueryForSet & subquery_for_set) const;
|
||||
|
||||
const ASTSelectQuery * getAggregatingQuery() const;
|
||||
/// appendProjectResult
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ ExternalModelsLoader::ExternalModelsLoader(Context & context_)
|
||||
: ExternalLoader("external model", &Logger::get("ExternalModelsLoader"))
|
||||
, context(context_)
|
||||
{
|
||||
setConfigSettings({"models", "name", {}});
|
||||
setConfigSettings({"model", "name", {}});
|
||||
enablePeriodicUpdates(true);
|
||||
}
|
||||
|
||||
|
@ -322,7 +322,7 @@ ColumnsDescription InterpreterCreateQuery::getColumnsDescription(const ASTExpres
|
||||
{
|
||||
auto syntax_analyzer_result = SyntaxAnalyzer(context).analyze(default_expr_list, column_names_and_types);
|
||||
const auto actions = ExpressionAnalyzer(default_expr_list, syntax_analyzer_result, context).getActions(true);
|
||||
for (auto action : actions->getActions())
|
||||
for (auto & action : actions->getActions())
|
||||
if (action.type == ExpressionAction::Type::JOIN || action.type == ExpressionAction::Type::ARRAY_JOIN)
|
||||
throw Exception("Cannot CREATE table. Unsupported default value that requires ARRAY JOIN or JOIN action", ErrorCodes::THERE_IS_NO_DEFAULT_VALUE);
|
||||
|
||||
@ -545,10 +545,12 @@ BlockIO InterpreterCreateQuery::createTable(ASTCreateQuery & create)
|
||||
// If this is a stub ATTACH query, read the query definition from the database
|
||||
if (create.attach && !create.storage && !create.columns_list)
|
||||
{
|
||||
bool if_not_exists = create.if_not_exists;
|
||||
// Table SQL definition is available even if the table is detached
|
||||
auto query = context.getDatabase(create.database)->getCreateTableQuery(context, create.table);
|
||||
create = query->as<ASTCreateQuery &>(); // Copy the saved create query, but use ATTACH instead of CREATE
|
||||
create.attach = true;
|
||||
create.if_not_exists = if_not_exists;
|
||||
}
|
||||
|
||||
String current_database = context.getCurrentDatabase();
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <Storages/IndicesDescription.h>
|
||||
#include <Storages/ConstraintsDescription.h>
|
||||
#include <Common/ThreadPool.h>
|
||||
#include <Access/AccessRightsElement.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
|
72
dbms/src/Interpreters/InterpreterCreateUserQuery.cpp
Normal file
72
dbms/src/Interpreters/InterpreterCreateUserQuery.cpp
Normal file
@ -0,0 +1,72 @@
|
||||
#include <Interpreters/InterpreterCreateUserQuery.h>
|
||||
#include <Parsers/ASTCreateUserQuery.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Access/AccessControlManager.h>
|
||||
#include <Access/User.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
BlockIO InterpreterCreateUserQuery::execute()
|
||||
{
|
||||
const auto & query = query_ptr->as<const ASTCreateUserQuery &>();
|
||||
auto & access_control = context.getAccessControlManager();
|
||||
context.checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER);
|
||||
|
||||
if (query.alter)
|
||||
{
|
||||
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
|
||||
{
|
||||
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone());
|
||||
updateUserFromQuery(*updated_user, query);
|
||||
return updated_user;
|
||||
};
|
||||
if (query.if_exists)
|
||||
{
|
||||
if (auto id = access_control.find<User>(query.name))
|
||||
access_control.tryUpdate(*id, update_func);
|
||||
}
|
||||
else
|
||||
access_control.update(access_control.getID<User>(query.name), update_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto new_user = std::make_shared<User>();
|
||||
updateUserFromQuery(*new_user, query);
|
||||
|
||||
if (query.if_not_exists)
|
||||
access_control.tryInsert(new_user);
|
||||
else if (query.or_replace)
|
||||
access_control.insertOrReplace(new_user);
|
||||
else
|
||||
access_control.insert(new_user);
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreateUserQuery & query)
|
||||
{
|
||||
if (query.alter)
|
||||
{
|
||||
if (!query.new_name.empty())
|
||||
user.setName(query.new_name);
|
||||
}
|
||||
else
|
||||
user.setName(query.name);
|
||||
|
||||
if (query.authentication)
|
||||
user.authentication = *query.authentication;
|
||||
|
||||
if (query.hosts)
|
||||
user.allowed_client_hosts = *query.hosts;
|
||||
if (query.remove_hosts)
|
||||
user.allowed_client_hosts.remove(*query.remove_hosts);
|
||||
if (query.add_hosts)
|
||||
user.allowed_client_hosts.add(*query.add_hosts);
|
||||
|
||||
if (query.profile)
|
||||
user.profile = *query.profile;
|
||||
}
|
||||
}
|
26
dbms/src/Interpreters/InterpreterCreateUserQuery.h
Normal file
26
dbms/src/Interpreters/InterpreterCreateUserQuery.h
Normal file
@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include <Interpreters/IInterpreter.h>
|
||||
#include <Parsers/IAST_fwd.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
class ASTCreateUserQuery;
|
||||
struct User;
|
||||
|
||||
|
||||
class InterpreterCreateUserQuery : public IInterpreter
|
||||
{
|
||||
public:
|
||||
InterpreterCreateUserQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
|
||||
|
||||
BlockIO execute() override;
|
||||
|
||||
private:
|
||||
void updateUserFromQuery(User & quota, const ASTCreateUserQuery & query);
|
||||
|
||||
ASTPtr query_ptr;
|
||||
Context & context;
|
||||
};
|
||||
}
|
@ -5,6 +5,7 @@
|
||||
#include <Access/AccessFlags.h>
|
||||
#include <Access/Quota.h>
|
||||
#include <Access/RowPolicy.h>
|
||||
#include <Access/User.h>
|
||||
#include <boost/range/algorithm/transform.hpp>
|
||||
|
||||
|
||||
@ -18,6 +19,16 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
|
||||
|
||||
switch (query.kind)
|
||||
{
|
||||
case Kind::USER:
|
||||
{
|
||||
context.checkAccess(AccessType::DROP_USER);
|
||||
if (query.if_exists)
|
||||
access_control.tryRemove(access_control.find<User>(query.names));
|
||||
else
|
||||
access_control.remove(access_control.getIDs<User>(query.names));
|
||||
return {};
|
||||
}
|
||||
|
||||
case Kind::QUOTA:
|
||||
{
|
||||
context.checkAccess(AccessType::DROP_QUOTA);
|
||||
@ -27,6 +38,7 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
|
||||
access_control.remove(access_control.getIDs<Quota>(query.names));
|
||||
return {};
|
||||
}
|
||||
|
||||
case Kind::ROW_POLICY:
|
||||
{
|
||||
context.checkAccess(AccessType::DROP_POLICY);
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <Parsers/ASTAlterQuery.h>
|
||||
#include <Parsers/ASTCheckQuery.h>
|
||||
#include <Parsers/ASTCreateQuery.h>
|
||||
#include <Parsers/ASTCreateUserQuery.h>
|
||||
#include <Parsers/ASTCreateQuotaQuery.h>
|
||||
#include <Parsers/ASTCreateRowPolicyQuery.h>
|
||||
#include <Parsers/ASTDropAccessEntityQuery.h>
|
||||
@ -14,6 +15,7 @@
|
||||
#include <Parsers/ASTSetQuery.h>
|
||||
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
|
||||
#include <Parsers/ASTShowProcesslistQuery.h>
|
||||
#include <Parsers/ASTShowGrantsQuery.h>
|
||||
#include <Parsers/ASTShowQuotasQuery.h>
|
||||
#include <Parsers/ASTShowRowPoliciesQuery.h>
|
||||
#include <Parsers/ASTShowTablesQuery.h>
|
||||
@ -21,10 +23,12 @@
|
||||
#include <Parsers/ASTExplainQuery.h>
|
||||
#include <Parsers/TablePropertiesQueriesASTs.h>
|
||||
#include <Parsers/ASTWatchQuery.h>
|
||||
#include <Parsers/ASTGrantQuery.h>
|
||||
|
||||
#include <Interpreters/InterpreterAlterQuery.h>
|
||||
#include <Interpreters/InterpreterCheckQuery.h>
|
||||
#include <Interpreters/InterpreterCreateQuery.h>
|
||||
#include <Interpreters/InterpreterCreateUserQuery.h>
|
||||
#include <Interpreters/InterpreterCreateQuotaQuery.h>
|
||||
#include <Interpreters/InterpreterCreateRowPolicyQuery.h>
|
||||
#include <Interpreters/InterpreterDescribeQuery.h>
|
||||
@ -43,12 +47,14 @@
|
||||
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
|
||||
#include <Interpreters/InterpreterShowCreateQuery.h>
|
||||
#include <Interpreters/InterpreterShowProcesslistQuery.h>
|
||||
#include <Interpreters/InterpreterShowGrantsQuery.h>
|
||||
#include <Interpreters/InterpreterShowQuotasQuery.h>
|
||||
#include <Interpreters/InterpreterShowRowPoliciesQuery.h>
|
||||
#include <Interpreters/InterpreterShowTablesQuery.h>
|
||||
#include <Interpreters/InterpreterSystemQuery.h>
|
||||
#include <Interpreters/InterpreterUseQuery.h>
|
||||
#include <Interpreters/InterpreterWatchQuery.h>
|
||||
#include <Interpreters/InterpreterGrantQuery.h>
|
||||
|
||||
#include <Parsers/ASTSystemQuery.h>
|
||||
|
||||
@ -176,6 +182,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
|
||||
{
|
||||
return std::make_unique<InterpreterWatchQuery>(query, context);
|
||||
}
|
||||
else if (query->as<ASTCreateUserQuery>())
|
||||
{
|
||||
return std::make_unique<InterpreterCreateUserQuery>(query, context);
|
||||
}
|
||||
else if (query->as<ASTCreateQuotaQuery>())
|
||||
{
|
||||
return std::make_unique<InterpreterCreateQuotaQuery>(query, context);
|
||||
@ -188,10 +198,18 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
|
||||
{
|
||||
return std::make_unique<InterpreterDropAccessEntityQuery>(query, context);
|
||||
}
|
||||
else if (query->as<ASTGrantQuery>())
|
||||
{
|
||||
return std::make_unique<InterpreterGrantQuery>(query, context);
|
||||
}
|
||||
else if (query->as<ASTShowCreateAccessEntityQuery>())
|
||||
{
|
||||
return std::make_unique<InterpreterShowCreateAccessEntityQuery>(query, context);
|
||||
}
|
||||
else if (query->as<ASTShowGrantsQuery>())
|
||||
{
|
||||
return std::make_unique<InterpreterShowGrantsQuery>(query, context);
|
||||
}
|
||||
else if (query->as<ASTShowQuotasQuery>())
|
||||
{
|
||||
return std::make_unique<InterpreterShowQuotasQuery>(query, context);
|
||||
|
57
dbms/src/Interpreters/InterpreterGrantQuery.cpp
Normal file
57
dbms/src/Interpreters/InterpreterGrantQuery.cpp
Normal file
@ -0,0 +1,57 @@
|
||||
#include <Interpreters/InterpreterGrantQuery.h>
|
||||
#include <Parsers/ASTGrantQuery.h>
|
||||
#include <Parsers/ASTRoleList.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Access/AccessControlManager.h>
|
||||
#include <Access/AccessRightsContext.h>
|
||||
#include <Access/User.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
BlockIO InterpreterGrantQuery::execute()
|
||||
{
|
||||
const auto & query = query_ptr->as<const ASTGrantQuery &>();
|
||||
auto & access_control = context.getAccessControlManager();
|
||||
context.getAccessRights()->checkGrantOption(query.access_rights_elements);
|
||||
|
||||
using Kind = ASTGrantQuery::Kind;
|
||||
|
||||
if (query.to_roles->all_roles)
|
||||
throw Exception(
|
||||
"Cannot " + String((query.kind == Kind::GRANT) ? "GRANT to" : "REVOKE from") + " ALL", ErrorCodes::NOT_IMPLEMENTED);
|
||||
|
||||
String current_database = context.getCurrentDatabase();
|
||||
|
||||
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
|
||||
{
|
||||
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone());
|
||||
if (query.kind == Kind::GRANT)
|
||||
{
|
||||
updated_user->access.grant(query.access_rights_elements, current_database);
|
||||
if (query.grant_option)
|
||||
updated_user->access_with_grant_option.grant(query.access_rights_elements, current_database);
|
||||
}
|
||||
else if (context.getSettingsRef().partial_revokes)
|
||||
{
|
||||
updated_user->access_with_grant_option.partialRevoke(query.access_rights_elements, current_database);
|
||||
if (!query.grant_option)
|
||||
updated_user->access.partialRevoke(query.access_rights_elements, current_database);
|
||||
}
|
||||
else
|
||||
{
|
||||
updated_user->access_with_grant_option.revoke(query.access_rights_elements, current_database);
|
||||
if (!query.grant_option)
|
||||
updated_user->access.revoke(query.access_rights_elements, current_database);
|
||||
}
|
||||
return updated_user;
|
||||
};
|
||||
|
||||
std::vector<UUID> ids = access_control.getIDs<User>(query.to_roles->roles);
|
||||
if (query.to_roles->current_user)
|
||||
ids.push_back(context.getUserID());
|
||||
access_control.update(ids, update_func);
|
||||
return {};
|
||||
}
|
||||
|
||||
}
|
20
dbms/src/Interpreters/InterpreterGrantQuery.h
Normal file
20
dbms/src/Interpreters/InterpreterGrantQuery.h
Normal file
@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include <Interpreters/IInterpreter.h>
|
||||
#include <Parsers/IAST_fwd.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
class InterpreterGrantQuery : public IInterpreter
|
||||
{
|
||||
public:
|
||||
InterpreterGrantQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
|
||||
|
||||
BlockIO execute() override;
|
||||
|
||||
private:
|
||||
ASTPtr query_ptr;
|
||||
Context & context;
|
||||
};
|
||||
}
|
@ -154,9 +154,7 @@ String InterpreterSelectQuery::generateFilterActions(ExpressionActionsPtr & acti
|
||||
/// Using separate expression analyzer to prevent any possible alias injection
|
||||
auto syntax_result = SyntaxAnalyzer(*context).analyze(query_ast, storage->getColumns().getAllPhysical());
|
||||
SelectQueryExpressionAnalyzer analyzer(query_ast, syntax_result, *context);
|
||||
ExpressionActionsChain new_chain(*context);
|
||||
analyzer.appendSelect(new_chain, false);
|
||||
actions = new_chain.getLastActions();
|
||||
actions = analyzer.simpleSelectActions();
|
||||
|
||||
return expr_list->children.at(0)->getColumnName();
|
||||
}
|
||||
@ -212,17 +210,6 @@ static Context getSubqueryContext(const Context & context)
|
||||
return subquery_context;
|
||||
}
|
||||
|
||||
static void sanitizeBlock(Block & block)
|
||||
{
|
||||
for (auto & col : block)
|
||||
{
|
||||
if (!col.column)
|
||||
col.column = col.type->createColumn();
|
||||
else if (isColumnConst(*col.column) && !col.column->empty())
|
||||
col.column = col.column->cloneEmpty();
|
||||
}
|
||||
}
|
||||
|
||||
InterpreterSelectQuery::InterpreterSelectQuery(
|
||||
const ASTPtr & query_ptr_,
|
||||
const Context & context_,
|
||||
@ -324,7 +311,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
|
||||
table_id = storage->getStorageID();
|
||||
}
|
||||
|
||||
auto analyze = [&] ()
|
||||
auto analyze = [&] (bool try_move_to_prewhere = true)
|
||||
{
|
||||
syntax_analyzer_result = SyntaxAnalyzer(*context, options).analyze(
|
||||
query_ptr, source_header.getNamesAndTypesList(), required_result_column_names, storage, NamesAndTypesList());
|
||||
@ -397,7 +384,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
|
||||
throw Exception("PREWHERE is not supported if the table is filtered by row-level security expression", ErrorCodes::ILLEGAL_PREWHERE);
|
||||
|
||||
/// Calculate structure of the result.
|
||||
result_header = getSampleBlockImpl();
|
||||
result_header = getSampleBlockImpl(try_move_to_prewhere);
|
||||
};
|
||||
|
||||
analyze();
|
||||
@ -425,8 +412,13 @@ InterpreterSelectQuery::InterpreterSelectQuery(
|
||||
query.setExpression(ASTSelectQuery::Expression::WHERE, makeASTFunction("and", query.prewhere()->clone(), query.where()->clone()));
|
||||
need_analyze_again = true;
|
||||
}
|
||||
|
||||
if (need_analyze_again)
|
||||
analyze();
|
||||
{
|
||||
/// Do not try move conditions to PREWHERE for the second time.
|
||||
/// Otherwise, we won't be able to fallback from inefficient PREWHERE to WHERE later.
|
||||
analyze(/* try_move_to_prewhere = */ false);
|
||||
}
|
||||
|
||||
/// If there is no WHERE, filter blocks as usual
|
||||
if (query.prewhere() && !query.where())
|
||||
@ -509,7 +501,7 @@ QueryPipeline InterpreterSelectQuery::executeWithProcessors()
|
||||
}
|
||||
|
||||
|
||||
Block InterpreterSelectQuery::getSampleBlockImpl()
|
||||
Block InterpreterSelectQuery::getSampleBlockImpl(bool try_move_to_prewhere)
|
||||
{
|
||||
auto & query = getSelectQuery();
|
||||
const Settings & settings = context->getSettingsRef();
|
||||
@ -533,7 +525,7 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
|
||||
current_info.sets = query_analyzer->getPreparedSets();
|
||||
|
||||
/// Try transferring some condition from WHERE to PREWHERE if enabled and viable
|
||||
if (settings.optimize_move_to_prewhere && query.where() && !query.prewhere() && !query.final())
|
||||
if (settings.optimize_move_to_prewhere && try_move_to_prewhere && query.where() && !query.prewhere() && !query.final())
|
||||
MergeTreeWhereOptimizer{current_info, *context, merge_tree,
|
||||
syntax_analyzer_result->requiredSourceColumns(), log};
|
||||
};
|
||||
@ -546,13 +538,17 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
|
||||
if (storage && !options.only_analyze)
|
||||
from_stage = storage->getQueryProcessingStage(*context);
|
||||
|
||||
analysis_result = analyzeExpressions(
|
||||
getSelectQuery(),
|
||||
/// Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
|
||||
bool first_stage = from_stage < QueryProcessingStage::WithMergeableState
|
||||
&& options.to_stage >= QueryProcessingStage::WithMergeableState;
|
||||
/// Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
|
||||
bool second_stage = from_stage <= QueryProcessingStage::WithMergeableState
|
||||
&& options.to_stage > QueryProcessingStage::WithMergeableState;
|
||||
|
||||
analysis_result = ExpressionAnalysisResult(
|
||||
*query_analyzer,
|
||||
from_stage,
|
||||
options.to_stage,
|
||||
*context,
|
||||
storage,
|
||||
first_stage,
|
||||
second_stage,
|
||||
options.only_analyze,
|
||||
filter_info,
|
||||
source_header
|
||||
@ -579,16 +575,12 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
|
||||
|
||||
auto header = analysis_result.before_aggregation->getSampleBlock();
|
||||
|
||||
Names key_names;
|
||||
AggregateDescriptions aggregates;
|
||||
query_analyzer->getAggregateInfo(key_names, aggregates);
|
||||
|
||||
Block res;
|
||||
|
||||
for (auto & key : key_names)
|
||||
res.insert({nullptr, header.getByName(key).type, key});
|
||||
for (auto & key : query_analyzer->aggregationKeys())
|
||||
res.insert({nullptr, header.getByName(key.name).type, key.name});
|
||||
|
||||
for (auto & aggregate : aggregates)
|
||||
for (auto & aggregate : query_analyzer->aggregates())
|
||||
{
|
||||
size_t arguments_size = aggregate.argument_names.size();
|
||||
DataTypes argument_types(arguments_size);
|
||||
@ -606,246 +598,6 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
|
||||
return analysis_result.final_projection->getSampleBlock();
|
||||
}
|
||||
|
||||
/// Check if there is an ignore function. It's used for disabling constant folding in query
|
||||
/// predicates because some performance tests use ignore function as a non-optimize guard.
|
||||
static bool hasIgnore(const ExpressionActions & actions)
|
||||
{
|
||||
for (auto & action : actions.getActions())
|
||||
{
|
||||
if (action.type == action.APPLY_FUNCTION && action.function_base)
|
||||
{
|
||||
auto name = action.function_base->getName();
|
||||
if (name == "ignore")
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
InterpreterSelectQuery::AnalysisResult
|
||||
InterpreterSelectQuery::analyzeExpressions(
|
||||
const ASTSelectQuery & query,
|
||||
SelectQueryExpressionAnalyzer & query_analyzer,
|
||||
QueryProcessingStage::Enum from_stage,
|
||||
QueryProcessingStage::Enum to_stage,
|
||||
const Context & context,
|
||||
const StoragePtr & storage,
|
||||
bool only_types,
|
||||
const FilterInfoPtr & filter_info,
|
||||
const Block & source_header)
|
||||
{
|
||||
AnalysisResult res;
|
||||
|
||||
/// Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
|
||||
res.first_stage = from_stage < QueryProcessingStage::WithMergeableState
|
||||
&& to_stage >= QueryProcessingStage::WithMergeableState;
|
||||
/// Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
|
||||
res.second_stage = from_stage <= QueryProcessingStage::WithMergeableState
|
||||
&& to_stage > QueryProcessingStage::WithMergeableState;
|
||||
|
||||
/** First we compose a chain of actions and remember the necessary steps from it.
|
||||
* Regardless of from_stage and to_stage, we will compose a complete sequence of actions to perform optimization and
|
||||
* throw out unnecessary columns based on the entire query. In unnecessary parts of the query, we will not execute subqueries.
|
||||
*/
|
||||
|
||||
bool has_filter = false;
|
||||
bool has_prewhere = false;
|
||||
bool has_where = false;
|
||||
size_t where_step_num;
|
||||
|
||||
auto finalizeChain = [&](ExpressionActionsChain & chain)
|
||||
{
|
||||
chain.finalize();
|
||||
|
||||
if (has_prewhere)
|
||||
{
|
||||
const ExpressionActionsChain::Step & step = chain.steps.at(0);
|
||||
res.prewhere_info->remove_prewhere_column = step.can_remove_required_output.at(0);
|
||||
|
||||
Names columns_to_remove;
|
||||
for (size_t i = 1; i < step.required_output.size(); ++i)
|
||||
{
|
||||
if (step.can_remove_required_output[i])
|
||||
columns_to_remove.push_back(step.required_output[i]);
|
||||
}
|
||||
|
||||
if (!columns_to_remove.empty())
|
||||
{
|
||||
auto columns = res.prewhere_info->prewhere_actions->getSampleBlock().getNamesAndTypesList();
|
||||
ExpressionActionsPtr actions = std::make_shared<ExpressionActions>(columns, context);
|
||||
for (const auto & column : columns_to_remove)
|
||||
actions->add(ExpressionAction::removeColumn(column));
|
||||
|
||||
res.prewhere_info->remove_columns_actions = std::move(actions);
|
||||
}
|
||||
|
||||
res.columns_to_remove_after_prewhere = std::move(columns_to_remove);
|
||||
}
|
||||
else if (has_filter)
|
||||
{
|
||||
/// Can't have prewhere and filter set simultaneously
|
||||
res.filter_info->do_remove_column = chain.steps.at(0).can_remove_required_output.at(0);
|
||||
}
|
||||
if (has_where)
|
||||
res.remove_where_filter = chain.steps.at(where_step_num).can_remove_required_output.at(0);
|
||||
|
||||
has_filter = has_prewhere = has_where = false;
|
||||
|
||||
chain.clear();
|
||||
};
|
||||
|
||||
{
|
||||
ExpressionActionsChain chain(context);
|
||||
Names additional_required_columns_after_prewhere;
|
||||
|
||||
if (storage && (query.sample_size() || context.getSettingsRef().parallel_replicas_count > 1))
|
||||
{
|
||||
Names columns_for_sampling = storage->getColumnsRequiredForSampling();
|
||||
additional_required_columns_after_prewhere.insert(additional_required_columns_after_prewhere.end(),
|
||||
columns_for_sampling.begin(), columns_for_sampling.end());
|
||||
}
|
||||
|
||||
if (storage && query.final())
|
||||
{
|
||||
Names columns_for_final = storage->getColumnsRequiredForFinal();
|
||||
additional_required_columns_after_prewhere.insert(additional_required_columns_after_prewhere.end(),
|
||||
columns_for_final.begin(), columns_for_final.end());
|
||||
}
|
||||
|
||||
if (storage && filter_info)
|
||||
{
|
||||
has_filter = true;
|
||||
res.filter_info = filter_info;
|
||||
query_analyzer.appendPreliminaryFilter(chain, filter_info->actions, filter_info->column_name);
|
||||
}
|
||||
|
||||
if (query_analyzer.appendPrewhere(chain, !res.first_stage, additional_required_columns_after_prewhere))
|
||||
{
|
||||
has_prewhere = true;
|
||||
|
||||
res.prewhere_info = std::make_shared<PrewhereInfo>(
|
||||
chain.steps.front().actions, query.prewhere()->getColumnName());
|
||||
|
||||
if (!hasIgnore(*res.prewhere_info->prewhere_actions))
|
||||
{
|
||||
Block before_prewhere_sample = source_header;
|
||||
sanitizeBlock(before_prewhere_sample);
|
||||
res.prewhere_info->prewhere_actions->execute(before_prewhere_sample);
|
||||
auto & column_elem = before_prewhere_sample.getByName(query.prewhere()->getColumnName());
|
||||
/// If the filter column is a constant, record it.
|
||||
if (column_elem.column)
|
||||
res.prewhere_constant_filter_description = ConstantFilterDescription(*column_elem.column);
|
||||
}
|
||||
chain.addStep();
|
||||
}
|
||||
|
||||
res.need_aggregate = query_analyzer.hasAggregation();
|
||||
|
||||
query_analyzer.appendArrayJoin(chain, only_types || !res.first_stage);
|
||||
|
||||
if (query_analyzer.appendJoin(chain, only_types || !res.first_stage))
|
||||
{
|
||||
res.before_join = chain.getLastActions();
|
||||
if (!res.hasJoin())
|
||||
throw Exception("No expected JOIN", ErrorCodes::LOGICAL_ERROR);
|
||||
chain.addStep();
|
||||
}
|
||||
|
||||
if (query_analyzer.appendWhere(chain, only_types || !res.first_stage))
|
||||
{
|
||||
where_step_num = chain.steps.size() - 1;
|
||||
has_where = res.has_where = true;
|
||||
res.before_where = chain.getLastActions();
|
||||
if (!hasIgnore(*res.before_where))
|
||||
{
|
||||
Block before_where_sample;
|
||||
if (chain.steps.size() > 1)
|
||||
before_where_sample = chain.steps[chain.steps.size() - 2].actions->getSampleBlock();
|
||||
else
|
||||
before_where_sample = source_header;
|
||||
sanitizeBlock(before_where_sample);
|
||||
res.before_where->execute(before_where_sample);
|
||||
auto & column_elem = before_where_sample.getByName(query.where()->getColumnName());
|
||||
/// If the filter column is a constant, record it.
|
||||
if (column_elem.column)
|
||||
res.where_constant_filter_description = ConstantFilterDescription(*column_elem.column);
|
||||
}
|
||||
chain.addStep();
|
||||
}
|
||||
|
||||
if (res.need_aggregate)
|
||||
{
|
||||
query_analyzer.appendGroupBy(chain, only_types || !res.first_stage);
|
||||
query_analyzer.appendAggregateFunctionsArguments(chain, only_types || !res.first_stage);
|
||||
res.before_aggregation = chain.getLastActions();
|
||||
|
||||
finalizeChain(chain);
|
||||
|
||||
if (query_analyzer.appendHaving(chain, only_types || !res.second_stage))
|
||||
{
|
||||
res.has_having = true;
|
||||
res.before_having = chain.getLastActions();
|
||||
chain.addStep();
|
||||
}
|
||||
}
|
||||
|
||||
bool has_stream_with_non_joned_rows = (res.before_join && res.before_join->getTableJoinAlgo()->hasStreamWithNonJoinedRows());
|
||||
res.optimize_read_in_order =
|
||||
context.getSettingsRef().optimize_read_in_order
|
||||
&& storage && query.orderBy()
|
||||
&& !query_analyzer.hasAggregation()
|
||||
&& !query.final()
|
||||
&& !has_stream_with_non_joned_rows;
|
||||
|
||||
/// If there is aggregation, we execute expressions in SELECT and ORDER BY on the initiating server, otherwise on the source servers.
|
||||
query_analyzer.appendSelect(chain, only_types || (res.need_aggregate ? !res.second_stage : !res.first_stage));
|
||||
res.selected_columns = chain.getLastStep().required_output;
|
||||
res.has_order_by = query_analyzer.appendOrderBy(chain, only_types || (res.need_aggregate ? !res.second_stage : !res.first_stage), res.optimize_read_in_order);
|
||||
res.before_order_and_select = chain.getLastActions();
|
||||
chain.addStep();
|
||||
|
||||
if (query_analyzer.appendLimitBy(chain, only_types || !res.second_stage))
|
||||
{
|
||||
res.has_limit_by = true;
|
||||
res.before_limit_by = chain.getLastActions();
|
||||
chain.addStep();
|
||||
}
|
||||
|
||||
query_analyzer.appendProjectResult(chain);
|
||||
res.final_projection = chain.getLastActions();
|
||||
|
||||
finalizeChain(chain);
|
||||
}
|
||||
|
||||
/// Before executing WHERE and HAVING, remove the extra columns from the block (mostly the aggregation keys).
|
||||
if (res.filter_info)
|
||||
res.filter_info->actions->prependProjectInput();
|
||||
if (res.has_where)
|
||||
res.before_where->prependProjectInput();
|
||||
if (res.has_having)
|
||||
res.before_having->prependProjectInput();
|
||||
|
||||
res.subqueries_for_sets = query_analyzer.getSubqueriesForSets();
|
||||
|
||||
/// Check that PREWHERE doesn't contain unusual actions. Unusual actions are that can change number of rows.
|
||||
if (res.prewhere_info)
|
||||
{
|
||||
auto check_actions = [](const ExpressionActionsPtr & actions)
|
||||
{
|
||||
if (actions)
|
||||
for (const auto & action : actions->getActions())
|
||||
if (action.type == ExpressionAction::Type::JOIN || action.type == ExpressionAction::Type::ARRAY_JOIN)
|
||||
throw Exception("PREWHERE cannot contain ARRAY JOIN or JOIN action", ErrorCodes::ILLEGAL_PREWHERE);
|
||||
};
|
||||
|
||||
check_actions(res.prewhere_info->prewhere_actions);
|
||||
check_actions(res.prewhere_info->alias_actions);
|
||||
check_actions(res.prewhere_info->remove_columns_actions);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static Field getWithFillFieldValue(const ASTPtr & node, const Context & context)
|
||||
{
|
||||
const auto & [field, type] = evaluateConstantExpression(node, context);
|
||||
@ -989,6 +741,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
auto & query = getSelectQuery();
|
||||
const Settings & settings = context->getSettingsRef();
|
||||
auto & expressions = analysis_result;
|
||||
auto & subqueries_for_sets = query_analyzer->getSubqueriesForSets();
|
||||
|
||||
if (options.only_analyze)
|
||||
{
|
||||
@ -1077,7 +830,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
|
||||
if (expressions.first_stage)
|
||||
{
|
||||
if (expressions.filter_info)
|
||||
if (expressions.hasFilter())
|
||||
{
|
||||
if constexpr (pipeline_with_processors)
|
||||
{
|
||||
@ -1159,7 +912,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
}
|
||||
}
|
||||
|
||||
if (expressions.has_where)
|
||||
if (expressions.hasWhere())
|
||||
executeWhere(pipeline, expressions.before_where, expressions.remove_where_filter);
|
||||
|
||||
if (expressions.need_aggregate)
|
||||
@ -1175,7 +928,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
* but there is an ORDER or LIMIT,
|
||||
* then we will perform the preliminary sorting and LIMIT on the remote server.
|
||||
*/
|
||||
if (!expressions.second_stage && !expressions.need_aggregate && !expressions.has_having)
|
||||
if (!expressions.second_stage && !expressions.need_aggregate && !expressions.hasHaving())
|
||||
{
|
||||
if (expressions.has_order_by)
|
||||
executeOrder(pipeline, query_info.input_sorting_info);
|
||||
@ -1183,7 +936,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
if (expressions.has_order_by && query.limitLength())
|
||||
executeDistinct(pipeline, false, expressions.selected_columns);
|
||||
|
||||
if (expressions.has_limit_by)
|
||||
if (expressions.hasLimitBy())
|
||||
{
|
||||
executeExpression(pipeline, expressions.before_limit_by);
|
||||
executeLimitBy(pipeline);
|
||||
@ -1194,8 +947,8 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
}
|
||||
|
||||
// If there is no global subqueries, we can run subqueries only when receive them on server.
|
||||
if (!query_analyzer->hasGlobalSubqueries() && !expressions.subqueries_for_sets.empty())
|
||||
executeSubqueriesInSetsAndJoins(pipeline, expressions.subqueries_for_sets);
|
||||
if (!query_analyzer->hasGlobalSubqueries() && !subqueries_for_sets.empty())
|
||||
executeSubqueriesInSetsAndJoins(pipeline, subqueries_for_sets);
|
||||
}
|
||||
|
||||
if (expressions.second_stage)
|
||||
@ -1213,7 +966,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
if (query.group_by_with_totals)
|
||||
{
|
||||
bool final = !query.group_by_with_rollup && !query.group_by_with_cube;
|
||||
executeTotalsAndHaving(pipeline, expressions.has_having, expressions.before_having, aggregate_overflow_row, final);
|
||||
executeTotalsAndHaving(pipeline, expressions.hasHaving(), expressions.before_having, aggregate_overflow_row, final);
|
||||
}
|
||||
|
||||
if (query.group_by_with_rollup)
|
||||
@ -1221,14 +974,14 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
else if (query.group_by_with_cube)
|
||||
executeRollupOrCube(pipeline, Modificator::CUBE);
|
||||
|
||||
if ((query.group_by_with_rollup || query.group_by_with_cube) && expressions.has_having)
|
||||
if ((query.group_by_with_rollup || query.group_by_with_cube) && expressions.hasHaving())
|
||||
{
|
||||
if (query.group_by_with_totals)
|
||||
throw Exception("WITH TOTALS and WITH ROLLUP or CUBE are not supported together in presence of HAVING", ErrorCodes::NOT_IMPLEMENTED);
|
||||
executeHaving(pipeline, expressions.before_having);
|
||||
}
|
||||
}
|
||||
else if (expressions.has_having)
|
||||
else if (expressions.hasHaving())
|
||||
executeHaving(pipeline, expressions.before_having);
|
||||
|
||||
executeExpression(pipeline, expressions.before_order_and_select);
|
||||
@ -1256,7 +1009,8 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
/** Optimization - if there are several sources and there is LIMIT, then first apply the preliminary LIMIT,
|
||||
* limiting the number of rows in each up to `offset + limit`.
|
||||
*/
|
||||
if (query.limitLength() && !query.limit_with_ties && pipeline.hasMoreThanOneStream() && !query.distinct && !expressions.has_limit_by && !settings.extremes)
|
||||
if (query.limitLength() && !query.limit_with_ties && pipeline.hasMoreThanOneStream() &&
|
||||
!query.distinct && !expressions.hasLimitBy() && !settings.extremes)
|
||||
{
|
||||
executePreLimit(pipeline);
|
||||
}
|
||||
@ -1281,7 +1035,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
if (need_second_distinct_pass)
|
||||
executeDistinct(pipeline, false, expressions.selected_columns);
|
||||
|
||||
if (expressions.has_limit_by)
|
||||
if (expressions.hasLimitBy())
|
||||
{
|
||||
executeExpression(pipeline, expressions.before_limit_by);
|
||||
executeLimitBy(pipeline);
|
||||
@ -1301,8 +1055,8 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
|
||||
}
|
||||
}
|
||||
|
||||
if (query_analyzer->hasGlobalSubqueries() && !expressions.subqueries_for_sets.empty())
|
||||
executeSubqueriesInSetsAndJoins(pipeline, expressions.subqueries_for_sets);
|
||||
if (query_analyzer->hasGlobalSubqueries() && !subqueries_for_sets.empty())
|
||||
executeSubqueriesInSetsAndJoins(pipeline, subqueries_for_sets);
|
||||
}
|
||||
|
||||
template <typename TPipeline>
|
||||
@ -1324,9 +1078,7 @@ void InterpreterSelectQuery::executeFetchColumns(
|
||||
|| !query_analyzer->hasAggregation() || processing_stage != QueryProcessingStage::FetchColumns)
|
||||
return {};
|
||||
|
||||
Names key_names;
|
||||
AggregateDescriptions aggregates;
|
||||
query_analyzer->getAggregateInfo(key_names, aggregates);
|
||||
const AggregateDescriptions & aggregates = query_analyzer->aggregates();
|
||||
|
||||
if (aggregates.size() != 1)
|
||||
return {};
|
||||
@ -1639,7 +1391,7 @@ void InterpreterSelectQuery::executeFetchColumns(
|
||||
if (analysis_result.optimize_read_in_order)
|
||||
{
|
||||
query_info.order_by_optimizer = std::make_shared<ReadInOrderOptimizer>(
|
||||
query_analyzer->getOrderByActions(),
|
||||
analysis_result.order_by_elements_actions,
|
||||
getSortDescription(query, *context),
|
||||
query_info.syntax_analyzer_result);
|
||||
|
||||
@ -1866,14 +1618,12 @@ void InterpreterSelectQuery::executeAggregation(Pipeline & pipeline, const Expre
|
||||
stream = std::make_shared<ExpressionBlockInputStream>(stream, expression);
|
||||
});
|
||||
|
||||
Names key_names;
|
||||
AggregateDescriptions aggregates;
|
||||
query_analyzer->getAggregateInfo(key_names, aggregates);
|
||||
|
||||
Block header = pipeline.firstStream()->getHeader();
|
||||
ColumnNumbers keys;
|
||||
for (const auto & name : key_names)
|
||||
keys.push_back(header.getPositionByName(name));
|
||||
for (const auto & key : query_analyzer->aggregationKeys())
|
||||
keys.push_back(header.getPositionByName(key.name));
|
||||
|
||||
AggregateDescriptions aggregates = query_analyzer->aggregates();
|
||||
for (auto & descr : aggregates)
|
||||
if (descr.arguments.empty())
|
||||
for (const auto & name : descr.argument_names)
|
||||
@ -1932,14 +1682,12 @@ void InterpreterSelectQuery::executeAggregation(QueryPipeline & pipeline, const
|
||||
return std::make_shared<ExpressionTransform>(header, expression);
|
||||
});
|
||||
|
||||
Names key_names;
|
||||
AggregateDescriptions aggregates;
|
||||
query_analyzer->getAggregateInfo(key_names, aggregates);
|
||||
|
||||
Block header_before_aggregation = pipeline.getHeader();
|
||||
ColumnNumbers keys;
|
||||
for (const auto & name : key_names)
|
||||
keys.push_back(header_before_aggregation.getPositionByName(name));
|
||||
for (const auto & key : query_analyzer->aggregationKeys())
|
||||
keys.push_back(header_before_aggregation.getPositionByName(key.name));
|
||||
|
||||
AggregateDescriptions aggregates = query_analyzer->aggregates();
|
||||
for (auto & descr : aggregates)
|
||||
if (descr.arguments.empty())
|
||||
for (const auto & name : descr.argument_names)
|
||||
@ -2000,15 +1748,11 @@ void InterpreterSelectQuery::executeAggregation(QueryPipeline & pipeline, const
|
||||
|
||||
void InterpreterSelectQuery::executeMergeAggregated(Pipeline & pipeline, bool overflow_row, bool final)
|
||||
{
|
||||
Names key_names;
|
||||
AggregateDescriptions aggregates;
|
||||
query_analyzer->getAggregateInfo(key_names, aggregates);
|
||||
|
||||
Block header = pipeline.firstStream()->getHeader();
|
||||
|
||||
ColumnNumbers keys;
|
||||
for (const auto & name : key_names)
|
||||
keys.push_back(header.getPositionByName(name));
|
||||
for (const auto & key : query_analyzer->aggregationKeys())
|
||||
keys.push_back(header.getPositionByName(key.name));
|
||||
|
||||
/** There are two modes of distributed aggregation.
|
||||
*
|
||||
@ -2027,7 +1771,7 @@ void InterpreterSelectQuery::executeMergeAggregated(Pipeline & pipeline, bool ov
|
||||
|
||||
const Settings & settings = context->getSettingsRef();
|
||||
|
||||
Aggregator::Params params(header, keys, aggregates, overflow_row, settings.max_threads);
|
||||
Aggregator::Params params(header, keys, query_analyzer->aggregates(), overflow_row, settings.max_threads);
|
||||
|
||||
if (!settings.distributed_aggregation_memory_efficient)
|
||||
{
|
||||
@ -2051,15 +1795,11 @@ void InterpreterSelectQuery::executeMergeAggregated(Pipeline & pipeline, bool ov
|
||||
|
||||
void InterpreterSelectQuery::executeMergeAggregated(QueryPipeline & pipeline, bool overflow_row, bool final)
|
||||
{
|
||||
Names key_names;
|
||||
AggregateDescriptions aggregates;
|
||||
query_analyzer->getAggregateInfo(key_names, aggregates);
|
||||
|
||||
Block header_before_merge = pipeline.getHeader();
|
||||
|
||||
ColumnNumbers keys;
|
||||
for (const auto & name : key_names)
|
||||
keys.push_back(header_before_merge.getPositionByName(name));
|
||||
for (const auto & key : query_analyzer->aggregationKeys())
|
||||
keys.push_back(header_before_merge.getPositionByName(key.name));
|
||||
|
||||
/** There are two modes of distributed aggregation.
|
||||
*
|
||||
@ -2078,7 +1818,7 @@ void InterpreterSelectQuery::executeMergeAggregated(QueryPipeline & pipeline, bo
|
||||
|
||||
const Settings & settings = context->getSettingsRef();
|
||||
|
||||
Aggregator::Params params(header_before_merge, keys, aggregates, overflow_row, settings.max_threads);
|
||||
Aggregator::Params params(header_before_merge, keys, query_analyzer->aggregates(), overflow_row, settings.max_threads);
|
||||
|
||||
auto transform_params = std::make_shared<AggregatingTransformParams>(params, final);
|
||||
|
||||
@ -2167,20 +1907,16 @@ void InterpreterSelectQuery::executeRollupOrCube(Pipeline & pipeline, Modificato
|
||||
{
|
||||
executeUnion(pipeline, {});
|
||||
|
||||
Names key_names;
|
||||
AggregateDescriptions aggregates;
|
||||
query_analyzer->getAggregateInfo(key_names, aggregates);
|
||||
|
||||
Block header = pipeline.firstStream()->getHeader();
|
||||
|
||||
ColumnNumbers keys;
|
||||
|
||||
for (const auto & name : key_names)
|
||||
keys.push_back(header.getPositionByName(name));
|
||||
for (const auto & key : query_analyzer->aggregationKeys())
|
||||
keys.push_back(header.getPositionByName(key.name));
|
||||
|
||||
const Settings & settings = context->getSettingsRef();
|
||||
|
||||
Aggregator::Params params(header, keys, aggregates,
|
||||
Aggregator::Params params(header, keys, query_analyzer->aggregates(),
|
||||
false, settings.max_rows_to_group_by, settings.group_by_overflow_mode,
|
||||
SettingUInt64(0), SettingUInt64(0),
|
||||
settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set,
|
||||
@ -2196,20 +1932,16 @@ void InterpreterSelectQuery::executeRollupOrCube(QueryPipeline & pipeline, Modif
|
||||
{
|
||||
pipeline.resize(1);
|
||||
|
||||
Names key_names;
|
||||
AggregateDescriptions aggregates;
|
||||
query_analyzer->getAggregateInfo(key_names, aggregates);
|
||||
|
||||
Block header_before_transform = pipeline.getHeader();
|
||||
|
||||
ColumnNumbers keys;
|
||||
|
||||
for (const auto & name : key_names)
|
||||
keys.push_back(header_before_transform.getPositionByName(name));
|
||||
for (const auto & key : query_analyzer->aggregationKeys())
|
||||
keys.push_back(header_before_transform.getPositionByName(key.name));
|
||||
|
||||
const Settings & settings = context->getSettingsRef();
|
||||
|
||||
Aggregator::Params params(header_before_transform, keys, aggregates,
|
||||
Aggregator::Params params(header_before_transform, keys, query_analyzer->aggregates(),
|
||||
false, settings.max_rows_to_group_by, settings.group_by_overflow_mode,
|
||||
SettingUInt64(0), SettingUInt64(0),
|
||||
settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set,
|
||||
@ -2806,7 +2538,7 @@ void InterpreterSelectQuery::executeExtremes(QueryPipeline & pipeline)
|
||||
}
|
||||
|
||||
|
||||
void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(Pipeline & pipeline, SubqueriesForSets & subqueries_for_sets)
|
||||
void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(Pipeline & pipeline, const SubqueriesForSets & subqueries_for_sets)
|
||||
{
|
||||
/// Merge streams to one. Use MergeSorting if data was read in sorted order, Union otherwise.
|
||||
if (query_info.input_sorting_info)
|
||||
@ -2822,7 +2554,7 @@ void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(Pipeline & pipeline
|
||||
pipeline.firstStream(), subqueries_for_sets, *context);
|
||||
}
|
||||
|
||||
void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(QueryPipeline & pipeline, SubqueriesForSets & subqueries_for_sets)
|
||||
void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(QueryPipeline & pipeline, const SubqueriesForSets & subqueries_for_sets)
|
||||
{
|
||||
if (query_info.input_sorting_info)
|
||||
executeMergeSorted(pipeline, query_info.input_sorting_info->order_key_prefix_descr, 0);
|
||||
|
@ -104,7 +104,7 @@ private:
|
||||
|
||||
ASTSelectQuery & getSelectQuery() { return query_ptr->as<ASTSelectQuery &>(); }
|
||||
|
||||
Block getSampleBlockImpl();
|
||||
Block getSampleBlockImpl(bool try_move_to_prewhere);
|
||||
|
||||
struct Pipeline
|
||||
{
|
||||
@ -152,55 +152,6 @@ private:
|
||||
template <typename TPipeline>
|
||||
void executeImpl(TPipeline & pipeline, const BlockInputStreamPtr & prepared_input, std::optional<Pipe> prepared_pipe, QueryPipeline & save_context_and_storage);
|
||||
|
||||
struct AnalysisResult
|
||||
{
|
||||
bool hasJoin() const { return before_join.get(); }
|
||||
bool has_where = false;
|
||||
bool need_aggregate = false;
|
||||
bool has_having = false;
|
||||
bool has_order_by = false;
|
||||
bool has_limit_by = false;
|
||||
|
||||
bool remove_where_filter = false;
|
||||
bool optimize_read_in_order = false;
|
||||
|
||||
ExpressionActionsPtr before_join; /// including JOIN
|
||||
ExpressionActionsPtr before_where;
|
||||
ExpressionActionsPtr before_aggregation;
|
||||
ExpressionActionsPtr before_having;
|
||||
ExpressionActionsPtr before_order_and_select;
|
||||
ExpressionActionsPtr before_limit_by;
|
||||
ExpressionActionsPtr final_projection;
|
||||
|
||||
/// Columns from the SELECT list, before renaming them to aliases.
|
||||
Names selected_columns;
|
||||
|
||||
/// Columns will be removed after prewhere actions execution.
|
||||
Names columns_to_remove_after_prewhere;
|
||||
|
||||
/// Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
|
||||
bool first_stage = false;
|
||||
/// Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
|
||||
bool second_stage = false;
|
||||
|
||||
SubqueriesForSets subqueries_for_sets;
|
||||
PrewhereInfoPtr prewhere_info;
|
||||
FilterInfoPtr filter_info;
|
||||
ConstantFilterDescription prewhere_constant_filter_description;
|
||||
ConstantFilterDescription where_constant_filter_description;
|
||||
};
|
||||
|
||||
static AnalysisResult analyzeExpressions(
|
||||
const ASTSelectQuery & query,
|
||||
SelectQueryExpressionAnalyzer & query_analyzer,
|
||||
QueryProcessingStage::Enum from_stage,
|
||||
QueryProcessingStage::Enum to_stage,
|
||||
const Context & context,
|
||||
const StoragePtr & storage,
|
||||
bool only_types,
|
||||
const FilterInfoPtr & filter_info,
|
||||
const Block & source_header);
|
||||
|
||||
/** From which table to read. With JOIN, the "left" table is returned.
|
||||
*/
|
||||
static void getDatabaseAndTableNames(const ASTSelectQuery & query, String & database_name, String & table_name, const Context & context);
|
||||
@ -232,7 +183,7 @@ private:
|
||||
void executeProjection(Pipeline & pipeline, const ExpressionActionsPtr & expression);
|
||||
void executeDistinct(Pipeline & pipeline, bool before_order, Names columns);
|
||||
void executeExtremes(Pipeline & pipeline);
|
||||
void executeSubqueriesInSetsAndJoins(Pipeline & pipeline, std::unordered_map<String, SubqueryForSet> & subqueries_for_sets);
|
||||
void executeSubqueriesInSetsAndJoins(Pipeline & pipeline, const std::unordered_map<String, SubqueryForSet> & subqueries_for_sets);
|
||||
void executeMergeSorted(Pipeline & pipeline, const SortDescription & sort_description, UInt64 limit);
|
||||
|
||||
void executeWhere(QueryPipeline & pipeline, const ExpressionActionsPtr & expression, bool remove_fiter);
|
||||
@ -250,7 +201,7 @@ private:
|
||||
void executeProjection(QueryPipeline & pipeline, const ExpressionActionsPtr & expression);
|
||||
void executeDistinct(QueryPipeline & pipeline, bool before_order, Names columns);
|
||||
void executeExtremes(QueryPipeline & pipeline);
|
||||
void executeSubqueriesInSetsAndJoins(QueryPipeline & pipeline, std::unordered_map<String, SubqueryForSet> & subqueries_for_sets);
|
||||
void executeSubqueriesInSetsAndJoins(QueryPipeline & pipeline, const std::unordered_map<String, SubqueryForSet> & subqueries_for_sets);
|
||||
void executeMergeSorted(QueryPipeline & pipeline, const SortDescription & sort_description, UInt64 limit);
|
||||
|
||||
String generateFilterActions(ExpressionActionsPtr & actions, const ASTPtr & row_policy_filter, const Names & prerequisite_columns = {}) const;
|
||||
@ -284,7 +235,7 @@ private:
|
||||
SelectQueryInfo query_info;
|
||||
|
||||
/// Is calculated in getSampleBlock. Is used later in readImpl.
|
||||
AnalysisResult analysis_result;
|
||||
ExpressionAnalysisResult analysis_result;
|
||||
FilterInfoPtr filter_info;
|
||||
|
||||
QueryProcessingStage::Enum from_stage = QueryProcessingStage::FetchColumns;
|
||||
|
@ -1,5 +1,6 @@
|
||||
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Parsers/ASTCreateUserQuery.h>
|
||||
#include <Parsers/ASTCreateQuotaQuery.h>
|
||||
#include <Parsers/ASTCreateRowPolicyQuery.h>
|
||||
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
|
||||
@ -9,6 +10,7 @@
|
||||
#include <Parsers/parseQuery.h>
|
||||
#include <Access/AccessControlManager.h>
|
||||
#include <Access/QuotaContext.h>
|
||||
#include <Access/User.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <DataStreams/OneBlockInputStream.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
@ -58,6 +60,7 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(const ASTShowCreat
|
||||
using Kind = ASTShowCreateAccessEntityQuery::Kind;
|
||||
switch (show_query.kind)
|
||||
{
|
||||
case Kind::USER: return getCreateUserQuery(show_query);
|
||||
case Kind::QUOTA: return getCreateQuotaQuery(show_query);
|
||||
case Kind::ROW_POLICY: return getCreateRowPolicyQuery(show_query);
|
||||
}
|
||||
@ -65,6 +68,27 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(const ASTShowCreat
|
||||
}
|
||||
|
||||
|
||||
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateUserQuery(const ASTShowCreateAccessEntityQuery & show_query) const
|
||||
{
|
||||
UserPtr user;
|
||||
if (show_query.current_user)
|
||||
user = context.getUser();
|
||||
else
|
||||
user = context.getAccessControlManager().getUser(show_query.name);
|
||||
|
||||
auto create_query = std::make_shared<ASTCreateUserQuery>();
|
||||
create_query->name = user->getName();
|
||||
|
||||
if (!user->allowed_client_hosts.containsAnyHost())
|
||||
create_query->hosts = user->allowed_client_hosts;
|
||||
|
||||
if (!user->profile.empty())
|
||||
create_query->profile = user->profile;
|
||||
|
||||
return create_query;
|
||||
}
|
||||
|
||||
|
||||
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const
|
||||
{
|
||||
auto & access_control = context.getAccessControlManager();
|
||||
|
@ -29,6 +29,7 @@ private:
|
||||
|
||||
BlockInputStreamPtr executeImpl();
|
||||
ASTPtr getCreateQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
|
||||
ASTPtr getCreateUserQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
|
||||
ASTPtr getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
|
||||
ASTPtr getCreateRowPolicyQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
|
||||
};
|
||||
|
124
dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp
Normal file
124
dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp
Normal file
@ -0,0 +1,124 @@
|
||||
#include <Interpreters/InterpreterShowGrantsQuery.h>
|
||||
#include <Parsers/ASTShowGrantsQuery.h>
|
||||
#include <Parsers/ASTGrantQuery.h>
|
||||
#include <Parsers/ASTRoleList.h>
|
||||
#include <Parsers/formatAST.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Columns/ColumnString.h>
|
||||
#include <DataStreams/OneBlockInputStream.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <Access/AccessControlManager.h>
|
||||
#include <Access/User.h>
|
||||
#include <boost/range/adaptor/map.hpp>
|
||||
#include <boost/range/algorithm/copy.hpp>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace
|
||||
{
|
||||
std::vector<AccessRightsElements> groupByTable(AccessRightsElements && elements)
|
||||
{
|
||||
using Key = std::tuple<String, bool, String, bool>;
|
||||
std::map<Key, AccessRightsElements> grouping_map;
|
||||
for (auto & element : elements)
|
||||
{
|
||||
Key key(element.database, element.any_database, element.table, element.any_table);
|
||||
grouping_map[key].emplace_back(std::move(element));
|
||||
}
|
||||
std::vector<AccessRightsElements> res;
|
||||
res.reserve(grouping_map.size());
|
||||
boost::range::copy(grouping_map | boost::adaptors::map_values, std::back_inserter(res));
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
struct GroupedGrantsAndPartialRevokes
|
||||
{
|
||||
std::vector<AccessRightsElements> grants;
|
||||
std::vector<AccessRightsElements> partial_revokes;
|
||||
};
|
||||
|
||||
GroupedGrantsAndPartialRevokes groupByTable(AccessRights::Elements && elements)
|
||||
{
|
||||
GroupedGrantsAndPartialRevokes res;
|
||||
res.grants = groupByTable(std::move(elements.grants));
|
||||
res.partial_revokes = groupByTable(std::move(elements.partial_revokes));
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
BlockIO InterpreterShowGrantsQuery::execute()
|
||||
{
|
||||
BlockIO res;
|
||||
res.in = executeImpl();
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
BlockInputStreamPtr InterpreterShowGrantsQuery::executeImpl()
|
||||
{
|
||||
const auto & show_query = query_ptr->as<ASTShowGrantsQuery &>();
|
||||
|
||||
/// Build a create query.
|
||||
ASTs grant_queries = getGrantQueries(show_query);
|
||||
|
||||
/// Build the result column.
|
||||
MutableColumnPtr column = ColumnString::create();
|
||||
std::stringstream grant_ss;
|
||||
for (const auto & grant_query : grant_queries)
|
||||
{
|
||||
grant_ss.str("");
|
||||
formatAST(*grant_query, grant_ss, false, true);
|
||||
column->insert(grant_ss.str());
|
||||
}
|
||||
|
||||
/// Prepare description of the result column.
|
||||
std::stringstream desc_ss;
|
||||
formatAST(show_query, desc_ss, false, true);
|
||||
String desc = desc_ss.str();
|
||||
String prefix = "SHOW ";
|
||||
if (desc.starts_with(prefix))
|
||||
desc = desc.substr(prefix.length()); /// `desc` always starts with "SHOW ", so we can trim this prefix.
|
||||
|
||||
return std::make_shared<OneBlockInputStream>(Block{{std::move(column), std::make_shared<DataTypeString>(), desc}});
|
||||
}
|
||||
|
||||
|
||||
ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show_query) const
|
||||
{
|
||||
UserPtr user;
|
||||
if (show_query.current_user)
|
||||
user = context.getUser();
|
||||
else
|
||||
user = context.getAccessControlManager().getUser(show_query.name);
|
||||
|
||||
ASTs res;
|
||||
|
||||
for (bool grant_option : {true, false})
|
||||
{
|
||||
if (!grant_option && (user->access == user->access_with_grant_option))
|
||||
continue;
|
||||
const auto & access_rights = grant_option ? user->access_with_grant_option : user->access;
|
||||
const auto grouped_elements = groupByTable(access_rights.getElements());
|
||||
|
||||
using Kind = ASTGrantQuery::Kind;
|
||||
for (Kind kind : {Kind::GRANT, Kind::REVOKE})
|
||||
{
|
||||
for (const auto & elements : (kind == Kind::GRANT ? grouped_elements.grants : grouped_elements.partial_revokes))
|
||||
{
|
||||
auto grant_query = std::make_shared<ASTGrantQuery>();
|
||||
grant_query->kind = kind;
|
||||
grant_query->grant_option = grant_option;
|
||||
grant_query->to_roles = std::make_shared<ASTRoleList>();
|
||||
grant_query->to_roles->roles.push_back(user->getName());
|
||||
grant_query->access_rights_elements = elements;
|
||||
res.push_back(std::move(grant_query));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
}
|
26
dbms/src/Interpreters/InterpreterShowGrantsQuery.h
Normal file
26
dbms/src/Interpreters/InterpreterShowGrantsQuery.h
Normal file
@ -0,0 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include <Interpreters/IInterpreter.h>
|
||||
#include <Parsers/IAST_fwd.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
class ASTShowGrantsQuery;
|
||||
|
||||
|
||||
class InterpreterShowGrantsQuery : public IInterpreter
|
||||
{
|
||||
public:
|
||||
InterpreterShowGrantsQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
|
||||
|
||||
BlockIO execute() override;
|
||||
|
||||
private:
|
||||
BlockInputStreamPtr executeImpl();
|
||||
ASTs getGrantQueries(const ASTShowGrantsQuery & show_query) const;
|
||||
|
||||
ASTPtr query_ptr;
|
||||
Context & context;
|
||||
};
|
||||
}
|
@ -18,6 +18,7 @@
|
||||
#include <Parsers/ASTExpressionList.h>
|
||||
#include <Parsers/ASTSelectQuery.h>
|
||||
#include <Parsers/formatAST.h>
|
||||
#include <Parsers/ASTOrderByElement.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
|
||||
@ -525,6 +526,39 @@ ASTPtr MutationsInterpreter::prepareInterpreterSelectQuery(std::vector<Stage> &
|
||||
}
|
||||
select->setExpression(ASTSelectQuery::Expression::WHERE, std::move(where_expression));
|
||||
}
|
||||
auto metadata = storage->getInMemoryMetadata();
|
||||
/// We have to execute select in order of primary key
|
||||
/// because we don't sort results additionaly and don't have
|
||||
/// any guarantees on data order without ORDER BY. It's almost free, because we
|
||||
/// have optimization for data read in primary key order.
|
||||
if (metadata.order_by_ast)
|
||||
{
|
||||
ASTPtr dummy;
|
||||
|
||||
ASTPtr key_expr;
|
||||
if (metadata.primary_key_ast)
|
||||
key_expr = metadata.primary_key_ast;
|
||||
else
|
||||
key_expr = metadata.order_by_ast;
|
||||
|
||||
bool empty = false;
|
||||
/// In all other cases we cannot have empty key
|
||||
if (auto key_function = key_expr->as<ASTFunction>())
|
||||
empty = key_function->arguments->children.size() == 0;
|
||||
|
||||
/// Not explicitely spicified empty key
|
||||
if (!empty)
|
||||
{
|
||||
auto order_by_expr = std::make_shared<ASTOrderByElement>(1, 1, false, dummy, false, dummy, dummy, dummy);
|
||||
|
||||
|
||||
order_by_expr->children.push_back(key_expr);
|
||||
auto res = std::make_shared<ASTExpressionList>();
|
||||
res->children.push_back(order_by_expr);
|
||||
|
||||
select->setExpression(ASTSelectQuery::Expression::ORDER_BY, std::move(res));
|
||||
}
|
||||
}
|
||||
|
||||
return select;
|
||||
}
|
||||
|
@ -181,12 +181,12 @@ ProcessList::EntryPtr ProcessList::insert(const String & query_, const IAST * as
|
||||
/// You should specify this value in configuration for default profile,
|
||||
/// not for specific users, sessions or queries,
|
||||
/// because this setting is effectively global.
|
||||
total_memory_tracker.setOrRaiseLimit(settings.max_memory_usage_for_all_queries);
|
||||
total_memory_tracker.setOrRaiseHardLimit(settings.max_memory_usage_for_all_queries);
|
||||
total_memory_tracker.setDescription("(total)");
|
||||
|
||||
/// Track memory usage for all simultaneously running queries from single user.
|
||||
user_process_list.user_memory_tracker.setParent(&total_memory_tracker);
|
||||
user_process_list.user_memory_tracker.setOrRaiseLimit(settings.max_memory_usage_for_user);
|
||||
user_process_list.user_memory_tracker.setOrRaiseHardLimit(settings.max_memory_usage_for_user);
|
||||
user_process_list.user_memory_tracker.setDescription("(for user)");
|
||||
|
||||
/// Actualize thread group info
|
||||
@ -198,7 +198,9 @@ ProcessList::EntryPtr ProcessList::insert(const String & query_, const IAST * as
|
||||
thread_group->query = process_it->query;
|
||||
|
||||
/// Set query-level memory trackers
|
||||
thread_group->memory_tracker.setOrRaiseLimit(process_it->max_memory_usage);
|
||||
thread_group->memory_tracker.setOrRaiseHardLimit(process_it->max_memory_usage);
|
||||
thread_group->memory_tracker.setOrRaiseProfilerLimit(settings.memory_profiler_step);
|
||||
thread_group->memory_tracker.setProfilerStep(settings.memory_profiler_step);
|
||||
thread_group->memory_tracker.setDescription("(for query)");
|
||||
if (process_it->memory_tracker_fault_probability)
|
||||
thread_group->memory_tracker.setFaultProbability(process_it->memory_tracker_fault_probability);
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
|
||||
#include <Interpreters/Join.h>
|
||||
#include <Interpreters/MergeJoin.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <DataStreams/LazyBlockInputStream.h>
|
||||
|
||||
namespace DB
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user