Merge branch 'master' of github.com:yandex/ClickHouse

Commit 0eade98688 by Alexey Milovidov, 2020-02-13 14:45:04 +03:00
381 changed files with 9021 additions and 3095 deletions

@@ -52,12 +52,12 @@ IncludeCategories:
ReflowComments: false
AlignEscapedNewlinesLeft: false
AlignEscapedNewlines: DontAlign
AlignTrailingComments: true
# Not changed:
AccessModifierOffset: -4
AlignConsecutiveAssignments: false
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false

@@ -11,8 +11,3 @@ ClickHouse is an open-source column-oriented database management system that all
* [Blog](https://clickhouse.yandex/blog/en/) contains various ClickHouse-related articles, as well as announcements and reports about events.
* [Contacts](https://clickhouse.yandex/#contacts) can help you get your questions answered.
* You can also [fill in this form](https://forms.yandex.com/surveys/meet-yandex-clickhouse-team/) to meet the Yandex ClickHouse team in person.
## Upcoming Events
* [ClickHouse Meetup in San Francisco](https://www.eventbrite.com/e/clickhouse-february-meetup-registration-88496227599) on February 5.
* [ClickHouse Meetup in New York](https://www.meetup.com/Uber-Engineering-Events-New-York/events/268328663/) on February 11.

@@ -1,4 +1,4 @@
if (NOT APPLE AND NOT ARCH_32)
if (NOT ARCH_32)
option (USE_INTERNAL_LIBGSASL_LIBRARY "Set to FALSE to use system libgsasl library instead of bundled" ${NOT_UNBUNDLED})
endif ()
@@ -16,7 +16,7 @@ if (NOT USE_INTERNAL_LIBGSASL_LIBRARY)
endif ()
if (LIBGSASL_LIBRARY AND LIBGSASL_INCLUDE_DIR)
elseif (NOT MISSING_INTERNAL_LIBGSASL_LIBRARY AND NOT APPLE AND NOT ARCH_32)
elseif (NOT MISSING_INTERNAL_LIBGSASL_LIBRARY AND NOT ARCH_32)
set (LIBGSASL_INCLUDE_DIR ${ClickHouse_SOURCE_DIR}/contrib/libgsasl/src ${ClickHouse_SOURCE_DIR}/contrib/libgsasl/linux_x86_64/include)
set (USE_INTERNAL_LIBGSASL_LIBRARY 1)
set (LIBGSASL_LIBRARY libgsasl)

@@ -1,5 +1,5 @@
# FreeBSD: contrib/cppkafka/include/cppkafka/detail/endianness.h:53:23: error: 'betoh16' was not declared in this scope
if (NOT ARCH_ARM AND NOT ARCH_32 AND NOT APPLE AND NOT OS_FREEBSD AND OPENSSL_FOUND)
if (NOT ARCH_ARM AND NOT ARCH_32 AND NOT OS_FREEBSD AND OPENSSL_FOUND)
option (ENABLE_RDKAFKA "Enable kafka" ${ENABLE_LIBRARIES})
endif ()
@@ -10,7 +10,7 @@ endif ()
if (ENABLE_RDKAFKA)
if (OS_LINUX AND NOT ARCH_ARM AND USE_LIBGSASL)
if (NOT ARCH_ARM AND USE_LIBGSASL)
option (USE_INTERNAL_RDKAFKA_LIBRARY "Set to FALSE to use system librdkafka instead of the bundled" ${NOT_UNBUNDLED})
endif ()

@@ -146,3 +146,5 @@ target_compile_definitions(curl PRIVATE HAVE_CONFIG_H BUILDING_LIBCURL CURL_HIDD
target_include_directories(curl PUBLIC ${CURL_DIR}/include ${CURL_DIR}/lib .)
target_compile_definitions(curl PRIVATE OS="${CMAKE_SYSTEM_NAME}")
target_link_libraries(curl PRIVATE ssl)

@@ -1,3 +1,4 @@
#define CURL_CA_BUNDLE "/etc/ssl/certs/ca-certificates.crt"
#define CURL_DISABLE_FTP
#define CURL_DISABLE_TFTP
#define CURL_DISABLE_LDAP
@@ -9,9 +10,14 @@
#define SIZEOF_CURL_OFF_T 8
#define SIZEOF_SIZE_T 8
#define HAVE_ALARM
#define HAVE_FCNTL_O_NONBLOCK
#define HAVE_GETADDRINFO
#define HAVE_LONGLONG
#define HAVE_POLL_FINE
#define HAVE_SIGACTION
#define HAVE_SIGNAL
#define HAVE_SIGSETJMP
#define HAVE_SOCKET
#define HAVE_STRUCT_TIMEVAL
@@ -34,5 +40,11 @@
#define HAVE_ERRNO_H
#define HAVE_FCNTL_H
#define HAVE_NETDB_H
#define HAVE_NETINET_IN_H
#define HAVE_SETJMP_H
#define HAVE_SYS_STAT_H
#define HAVE_UNISTD_H
#define ENABLE_IPV6
#define USE_OPENSSL
#define USE_THREADS_POSIX

contrib/libgsasl vendored

@@ -1 +1 @@
Subproject commit 3b8948a4042e34fb00b4fb987535dc9e02e39040
Subproject commit 42ef20687042637252e64df1934b6d47771486d1

contrib/librdkafka vendored

@@ -1 +1 @@
Subproject commit 6160ec275a5bb0a4088ede3c5f2afde638bbef65
Subproject commit 4ffe54b4f59ee5ae3767f9f25dc14651a3384d62

@@ -23,6 +23,8 @@ set(SRCS
${RDKAFKA_SOURCE_DIR}/rdkafka_lz4.c
${RDKAFKA_SOURCE_DIR}/rdkafka_metadata.c
${RDKAFKA_SOURCE_DIR}/rdkafka_metadata_cache.c
${RDKAFKA_SOURCE_DIR}/rdkafka_mock.c
${RDKAFKA_SOURCE_DIR}/rdkafka_mock_handlers.c
${RDKAFKA_SOURCE_DIR}/rdkafka_msg.c
${RDKAFKA_SOURCE_DIR}/rdkafka_msgset_reader.c
${RDKAFKA_SOURCE_DIR}/rdkafka_msgset_writer.c

@@ -75,8 +75,18 @@
#define HAVE_STRNDUP 1
// strerror_r
#define HAVE_STRERROR_R 1
#ifdef __APPLE__
// pthread_setname_np
#define HAVE_PTHREAD_SETNAME_DARWIN 1
#if (__ENVIRONMENT_MAC_OS_X_VERSION_MIN_REQUIRED__ <= 101400)
#define _TTHREAD_EMULATE_TIMESPEC_GET_
#endif
#else
// pthread_setname_gnu
#define HAVE_PTHREAD_SETNAME_GNU 1
#endif
// python
//#define HAVE_PYTHON 1
// disable C11 threads for compatibility with old libc
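The HAVE_PTHREAD_SETNAME_DARWIN / HAVE_PTHREAD_SETNAME_GNU split above exists because pthread_setname_np() has two incompatible signatures. A minimal sketch of the difference (illustrative helper, not part of librdkafka):

#include <pthread.h>

// Hypothetical helper showing what the two defines select between.
static void set_current_thread_name(const char * name)
{
#if defined(__APPLE__)
    pthread_setname_np(name);                  // Darwin: can only name the calling thread
#else
    pthread_setname_np(pthread_self(), name);  // GNU/glibc: takes the target thread first
#endif
}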

contrib/poco vendored

@@ -1 +1 @@
Subproject commit d478f62bd93c9cd14eb343756ef73a4ae622ddf5
Subproject commit d805cf5ca4cf8bdc642261cfcbe7a0a241cb7298

@@ -575,6 +575,7 @@ void HTTPHandler::processQuery(
try
{
char b;
//FIXME looks like MSG_DONTWAIT is useless because of POCO_BROKEN_TIMEOUTS
int status = socket.receiveBytes(&b, 1, MSG_DONTWAIT | MSG_PEEK);
if (status == 0)
context.killCurrentQuery();
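This check leans on standard POSIX socket semantics: a non-blocking peek that returns 0 means the peer performed an orderly shutdown, while "no data yet" surfaces as -1 with EAGAIN/EWOULDBLOCK. A minimal sketch of the same probe with a raw file descriptor (assumption: plain POSIX sockets, not Poco):

#include <sys/socket.h>

// Returns true if the peer has closed the connection; does not consume any data.
static bool peer_closed(int fd)
{
    char b;
    return recv(fd, &b, 1, MSG_DONTWAIT | MSG_PEEK) == 0;
}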

@@ -218,7 +218,7 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
try
{
// For compatibility with the JavaScript MySQL client, the Native41 authentication plugin is used when possible (if the password is specified using double SHA1). Otherwise the SHA256 plugin is used.
auto user = connection_context.getUser(user_name);
auto user = connection_context.getAccessControlManager().getUser(user_name);
const DB::Authentication::Type user_auth_type = user->authentication.getType();
if (user_auth_type != DB::Authentication::DOUBLE_SHA1_PASSWORD && user_auth_type != DB::Authentication::PLAINTEXT_PASSWORD && user_auth_type != DB::Authentication::NO_PASSWORD)
{
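A rough sketch of the plugin choice the comment describes, using only the Authentication::Type values visible in this hunk (the plugin name strings are the standard MySQL ones and are an assumption here, not quoted from this file):

const auto type = user->authentication.getType();
const bool native41 = type == DB::Authentication::DOUBLE_SHA1_PASSWORD
    || type == DB::Authentication::PLAINTEXT_PASSWORD
    || type == DB::Authentication::NO_PASSWORD;
// Native41 keeps old JavaScript MySQL clients working; anything else falls back to SHA256.
const char * plugin = native41 ? "mysql_native_password" : "sha256_password";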

@@ -2,6 +2,7 @@
#include <Common/config.h>
#include <Poco/Net/TCPServerConnection.h>
#include <Common/getFQDNOrHostName.h>
#include <Common/CurrentMetrics.h>
#include <Core/MySQLProtocol.h>
#include "IServer.h"
@@ -9,6 +10,11 @@
#include <Poco/Net/SecureStreamSocket.h>
#endif
namespace CurrentMetrics
{
extern const Metric MySQLConnection;
}
namespace DB
{
/// Handler for MySQL wire protocol connections. Allows connecting to ClickHouse using a MySQL client.
@@ -20,6 +26,8 @@ public:
void run() final;
private:
CurrentMetrics::Increment metric_increment{CurrentMetrics::MySQLConnection};
/// Enables SSL, if the client requested it.
void finishHandshake(MySQLProtocol::HandshakeResponse &);
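The new metric_increment member is an RAII counter: CurrentMetrics::Increment bumps the MySQLConnection gauge on construction and decrements it on destruction, so the metric always equals the number of live handler objects. A hedged sketch of the pattern (the surrounding class is hypothetical):

namespace CurrentMetrics { extern const Metric MySQLConnection; }

class SomeConnectionHandler
{
    // Incremented when the handler is created, decremented when it is destroyed.
    CurrentMetrics::Increment metric_increment{CurrentMetrics::MySQLConnection};
};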

@@ -900,6 +900,10 @@ void TCPHandler::receiveQuery()
client_info.initial_query_id = client_info.current_query_id;
client_info.initial_address = client_info.current_address;
}
else
{
query_context->switchRowPolicy();
}
}
/// Per query settings.

@@ -185,7 +185,7 @@
<mlock_executable>false</mlock_executable>
<!-- Configuration of clusters that could be used in Distributed tables.
https://clickhouse.yandex/docs/en/table_engines/distributed/
https://clickhouse.tech/docs/en/operations/table_engines/distributed/
-->
<remote_servers incl="clickhouse_remote_servers" >
<!-- Test only shard config for testing distributed storage -->

@@ -35,45 +35,49 @@ AccessControlManager::~AccessControlManager()
}
UserPtr AccessControlManager::getUser(const String & user_name) const
UserPtr AccessControlManager::getUser(
const String & user_name, std::function<void(const UserPtr &)> on_change, ext::scope_guard * subscription) const
{
return getUser(user_name, {}, nullptr);
return getUser(getID<User>(user_name), std::move(on_change), subscription);
}
UserPtr AccessControlManager::getUser(
const String & user_name, const std::function<void(const UserPtr &)> & on_change, ext::scope_guard * subscription) const
const UUID & user_id, std::function<void(const UserPtr &)> on_change, ext::scope_guard * subscription) const
{
UUID id = getID<User>(user_name);
if (on_change && subscription)
{
*subscription = subscribeForChanges(id, [on_change](const UUID &, const AccessEntityPtr & user)
*subscription = subscribeForChanges(user_id, [on_change](const UUID &, const AccessEntityPtr & user)
{
if (user)
on_change(typeid_cast<UserPtr>(user));
});
}
return read<User>(id);
return read<User>(user_id);
}
UserPtr AccessControlManager::authorizeAndGetUser(
const String & user_name,
const String & password,
const Poco::Net::IPAddress & address) const
{
return authorizeAndGetUser(user_name, password, address, {}, nullptr);
}
UserPtr AccessControlManager::authorizeAndGetUser(
const String & user_name,
const String & password,
const Poco::Net::IPAddress & address,
const std::function<void(const UserPtr &)> & on_change,
std::function<void(const UserPtr &)> on_change,
ext::scope_guard * subscription) const
{
auto user = getUser(user_name, on_change, subscription);
user->allowed_client_hosts.checkContains(address, user_name);
user->authentication.checkPassword(password, user_name);
return authorizeAndGetUser(getID<User>(user_name), password, address, std::move(on_change), subscription);
}
UserPtr AccessControlManager::authorizeAndGetUser(
const UUID & user_id,
const String & password,
const Poco::Net::IPAddress & address,
std::function<void(const UserPtr &)> on_change,
ext::scope_guard * subscription) const
{
auto user = getUser(user_id, on_change, subscription);
user->allowed_client_hosts.checkContains(address, user->getName());
user->authentication.checkPassword(password, user->getName());
return user;
}
@@ -85,9 +89,9 @@ void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguratio
}
std::shared_ptr<const AccessRightsContext> AccessControlManager::getAccessRightsContext(const ClientInfo & client_info, const AccessRights & granted_to_user, const Settings & settings, const String & current_database)
std::shared_ptr<const AccessRightsContext> AccessControlManager::getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database)
{
return std::make_shared<AccessRightsContext>(client_info, granted_to_user, settings, current_database);
return std::make_shared<AccessRightsContext>(user, client_info, settings, current_database);
}
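The reworked API routes every lookup through a UUID and optionally hands back a change subscription. A hedged usage sketch (the caller code is hypothetical; the signatures match this diff):

ext::scope_guard subscription;
UserPtr user = access_control.getUser(
    "alice",
    [](const UserPtr & changed) { /* react to the updated user entry */ },
    &subscription);
// The on_change callback stays registered until `subscription` goes out of scope.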

@@ -42,12 +42,12 @@ public:
void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config);
UserPtr getUser(const String & user_name) const;
UserPtr getUser(const String & user_name, const std::function<void(const UserPtr &)> & on_change, ext::scope_guard * subscription) const;
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address) const;
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address, const std::function<void(const UserPtr &)> & on_change, ext::scope_guard * subscription) const;
UserPtr getUser(const String & user_name, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
UserPtr getUser(const UUID & user_id, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
UserPtr authorizeAndGetUser(const UUID & user_id, const String & password, const Poco::Net::IPAddress & address, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
std::shared_ptr<const AccessRightsContext> getAccessRightsContext(const ClientInfo & client_info, const AccessRights & granted_to_user, const Settings & settings, const String & current_database);
std::shared_ptr<const AccessRightsContext> getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database);
std::shared_ptr<QuotaContext>
createQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key);

@@ -1,9 +1,9 @@
#include <Access/AccessRights.h>
#include <Common/Exception.h>
#include <common/logger_useful.h>
#include <boost/range/adaptor/map.hpp>
#include <unordered_map>
namespace DB
{
namespace ErrorCodes
@@ -73,6 +73,7 @@ public:
inherited_access = src.inherited_access;
explicit_grants = src.explicit_grants;
partial_revokes = src.partial_revokes;
raw_access = src.raw_access;
access = src.access;
min_access = src.min_access;
max_access = src.max_access;
@@ -114,8 +115,12 @@ public:
access_to_grant = grantable;
}
explicit_grants |= access_to_grant - partial_revokes;
partial_revokes -= access_to_grant;
AccessFlags new_explicit_grants = access_to_grant - partial_revokes;
if (level == TABLE_LEVEL)
removeExplicitGrantsRec(new_explicit_grants);
removePartialRevokesRec(access_to_grant);
explicit_grants |= new_explicit_grants;
calculateAllAccessRec(helper);
}
@@ -147,16 +152,27 @@ public:
{
if constexpr (mode == NORMAL_REVOKE_MODE)
{
explicit_grants -= access_to_revoke;
if (level == TABLE_LEVEL)
removeExplicitGrantsRec(access_to_revoke);
else
removeExplicitGrants(access_to_revoke);
}
else if constexpr (mode == PARTIAL_REVOKE_MODE)
{
partial_revokes |= access_to_revoke - explicit_grants;
explicit_grants -= access_to_revoke;
AccessFlags new_partial_revokes = access_to_revoke - explicit_grants;
if (level == TABLE_LEVEL)
removeExplicitGrantsRec(access_to_revoke);
else
removeExplicitGrants(access_to_revoke);
removePartialRevokesRec(new_partial_revokes);
partial_revokes |= new_partial_revokes;
}
else /// mode == FULL_REVOKE_MODE
{
fullRevokeRec(access_to_revoke);
AccessFlags new_partial_revokes = access_to_revoke - explicit_grants;
removeExplicitGrantsRec(access_to_revoke);
removePartialRevokesRec(new_partial_revokes);
partial_revokes |= new_partial_revokes;
}
calculateAllAccessRec(helper);
}
@@ -272,6 +288,24 @@ public:
calculateAllAccessRec(helper);
}
void traceTree(Poco::Logger * log) const
{
LOG_TRACE(log, "Tree(" << level << "): name=" << (node_name ? *node_name : "NULL")
<< ", explicit_grants=" << explicit_grants.toString()
<< ", partial_revokes=" << partial_revokes.toString()
<< ", inherited_access=" << inherited_access.toString()
<< ", raw_access=" << raw_access.toString()
<< ", access=" << access.toString()
<< ", min_access=" << min_access.toString()
<< ", max_access=" << max_access.toString()
<< ", num_children=" << (children ? children->size() : 0));
if (children)
{
for (auto & child : *children | boost::adaptors::map_values)
child.traceTree(log);
}
}
private:
Node * tryGetChild(const std::string_view & name)
{
@@ -371,14 +405,28 @@ private:
calculateMinAndMaxAccess();
}
void fullRevokeRec(const AccessFlags & access_to_revoke)
void removeExplicitGrants(const AccessFlags & change)
{
explicit_grants -= access_to_revoke;
partial_revokes |= access_to_revoke;
explicit_grants -= change;
}
void removeExplicitGrantsRec(const AccessFlags & change)
{
removeExplicitGrants(change);
if (children)
{
for (auto & child : *children | boost::adaptors::map_values)
child.fullRevokeRec(access_to_revoke);
child.removeExplicitGrantsRec(change);
}
}
void removePartialRevokesRec(const AccessFlags & change)
{
partial_revokes -= change;
if (children)
{
for (auto & child : *children | boost::adaptors::map_values)
child.removePartialRevokesRec(change);
}
}
@@ -726,4 +774,13 @@ void AccessRights::merge(const AccessRights & other)
}
}
void AccessRights::traceTree() const
{
auto * log = &Poco::Logger::get("AccessRights");
if (root)
root->traceTree(log);
else
LOG_TRACE(log, "Tree: NULL");
}
}
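The tree above keeps two complementary sets per node, explicit_grants and partial_revokes, and the new Rec helpers keep them disjoint across levels. A worked example of the intended semantics (method names from this file; the exact revoke-mode selection is summarised, not quoted):

AccessRights rights;
rights.grant(AccessType::SELECT, "db");  // database node: explicit grant
// A partial revoke on "db"."table" (PARTIAL_REVOKE_MODE) records the revoke on the
// table node instead of deleting the grant: every table in "db" keeps SELECT
// except "db"."table".
// Granting SELECT on "db"."table" afterwards runs removePartialRevokesRec() first,
// so the revoke is undone rather than shadowed by a second explicit grant.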

@@ -130,6 +130,8 @@ private:
template <typename... Args>
AccessFlags getAccessImpl(const Args &... args) const;
void traceTree() const;
struct Node;
std::unique_ptr<Node> root;
};

@@ -1,4 +1,5 @@
#include <Access/AccessRightsContext.h>
#include <Access/User.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
#include <Core/Settings.h>
@@ -88,24 +89,23 @@ AccessRightsContext::AccessRightsContext()
}
AccessRightsContext::AccessRightsContext(const ClientInfo & client_info_, const AccessRights & granted_to_user_, const Settings & settings, const String & current_database_)
: user_name(client_info_.current_user)
, granted_to_user(granted_to_user_)
AccessRightsContext::AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_)
: user(user_)
, readonly(settings.readonly)
, allow_ddl(settings.allow_ddl)
, allow_introspection(settings.allow_introspection_functions)
, current_database(current_database_)
, interface(client_info_.interface)
, http_method(client_info_.http_method)
, trace_log(&Poco::Logger::get("AccessRightsContext (" + user_name + ")"))
, trace_log(&Poco::Logger::get("AccessRightsContext (" + user_->getName() + ")"))
{
}
template <int mode, typename... Args>
template <int mode, bool grant_option, typename... Args>
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const
{
auto result_access = calculateResultAccess();
auto result_access = calculateResultAccess(grant_option);
bool is_granted = result_access->isGranted(access, args...);
if (trace_log)
@@ -126,12 +126,21 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
auto show_error = [&](const String & msg, [[maybe_unused]] int error_code)
{
if constexpr (mode == THROW_IF_ACCESS_DENIED)
throw Exception(msg, error_code);
throw Exception(user->getName() + ": " + msg, error_code);
else if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED)
LOG_WARNING(log_, msg + formatSkippedMessage(args...));
LOG_WARNING(log_, user->getName() + ": " + msg + formatSkippedMessage(args...));
};
if (readonly && calculateResultAccess(false, allow_ddl, allow_introspection)->isGranted(access, args...))
if (grant_option && calculateResultAccess(false, readonly, allow_ddl, allow_introspection)->isGranted(access, args...))
{
show_error(
"Not enough privileges. "
"The required privileges have been granted, but without grant option. "
"To execute this query it's necessary to have the grant "
+ AccessRightsElement{access, args...}.toString() + " WITH GRANT OPTION",
ErrorCodes::ACCESS_DENIED);
}
else if (readonly && calculateResultAccess(false, false, allow_ddl, allow_introspection)->isGranted(access, args...))
{
if (interface == ClientInfo::Interface::HTTP && http_method == ClientInfo::HTTPMethod::GET)
show_error(
@@ -141,108 +150,116 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
else
show_error("Cannot execute query in readonly mode", ErrorCodes::READONLY);
}
else if (!allow_ddl && calculateResultAccess(readonly, true, allow_introspection)->isGranted(access, args...))
else if (!allow_ddl && calculateResultAccess(false, readonly, true, allow_introspection)->isGranted(access, args...))
{
show_error("Cannot execute query. DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED);
}
else if (!allow_introspection && calculateResultAccess(readonly, allow_ddl, true)->isGranted(access, args...))
else if (!allow_introspection && calculateResultAccess(false, readonly, allow_ddl, true)->isGranted(access, args...))
{
show_error("Introspection functions are disabled, because setting 'allow_introspection_functions' is set to 0", ErrorCodes::FUNCTION_NOT_ALLOWED);
}
else
{
show_error(
user_name + ": Not enough privileges. To perform this operation you should have grant "
+ AccessRightsElement{access, args...}.toString(),
"Not enough privileges. To execute this query it's necessary to have the grant "
+ AccessRightsElement{access, args...}.toString() + (grant_option ? " WITH GRANT OPTION" : ""),
ErrorCodes::ACCESS_DENIED);
}
return false;
}
template <int mode>
template <int mode, bool grant_option>
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElement & element) const
{
if (element.any_database)
{
return checkImpl<mode>(log_, element.access_flags);
return checkImpl<mode, grant_option>(log_, element.access_flags);
}
else if (element.any_table)
{
if (element.database.empty())
return checkImpl<mode>(log_, element.access_flags, current_database);
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database);
else
return checkImpl<mode>(log_, element.access_flags, element.database);
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database);
}
else if (element.any_column)
{
if (element.database.empty())
return checkImpl<mode>(log_, element.access_flags, current_database, element.table);
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database, element.table);
else
return checkImpl<mode>(log_, element.access_flags, element.database, element.table);
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table);
}
else
{
if (element.database.empty())
return checkImpl<mode>(log_, element.access_flags, current_database, element.table, element.columns);
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database, element.table, element.columns);
else
return checkImpl<mode>(log_, element.access_flags, element.database, element.table, element.columns);
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table, element.columns);
}
}
template <int mode>
template <int mode, bool grant_option>
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElements & elements) const
{
for (const auto & element : elements)
if (!checkImpl<mode>(log_, element))
if (!checkImpl<mode, grant_option>(log_, element))
return false;
return true;
}
void AccessRightsContext::check(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table, column); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access); }
void AccessRightsContext::check(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access); }
void AccessRightsContext::check(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
void AccessRightsContext::check(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table, column); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table, column); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, column); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, column); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); }
void AccessRightsContext::checkGrantOption(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess() const
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option) const
{
auto res = result_access_cache[0].load();
if (res)
return res;
return calculateResultAccess(readonly, allow_ddl, allow_introspection);
return calculateResultAccess(grant_option, readonly, allow_ddl, allow_introspection);
}
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const
{
size_t cache_index = static_cast<size_t>(readonly_ != readonly)
+ static_cast<size_t>(allow_ddl_ != allow_ddl) * 2 +
+ static_cast<size_t>(allow_introspection_ != allow_introspection) * 3;
+ static_cast<size_t>(allow_introspection_ != allow_introspection) * 3
+ static_cast<size_t>(grant_option) * 4;
assert(cache_index < std::size(result_access_cache));
auto cached = result_access_cache[cache_index].load();
if (cached)
@@ -256,7 +273,7 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
auto result_ptr = boost::make_shared<AccessRights>();
auto & result = *result_ptr;
result = granted_to_user;
result = grant_option ? user->access_with_grant_option : user->access;
static const AccessFlags table_ddl = AccessType::CREATE_DATABASE | AccessType::CREATE_TABLE | AccessType::CREATE_VIEW
| AccessType::ALTER_TABLE | AccessType::ALTER_VIEW | AccessType::DROP_DATABASE | AccessType::DROP_TABLE | AccessType::DROP_VIEW
@@ -265,12 +282,18 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
static const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl;
static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE;
/// Anyone has access to the "system" database.
result.grant(AccessType::SELECT, "system");
if (readonly_)
result.fullRevoke(write_table_access | AccessType::SYSTEM);
if (readonly_ || !allow_ddl_)
result.fullRevoke(table_and_dictionary_ddl);
if (readonly_ && grant_option)
result.fullRevoke(AccessType::ALL);
if (readonly_ == 1)
{
/// Table functions are forbidden in readonly mode.
@@ -282,7 +305,11 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
result.fullRevoke(AccessType::INTROSPECTION);
result_access_cache[cache_index].store(result_ptr);
return std::move(result_ptr);
if (trace_log && (readonly == readonly_) && (allow_ddl == allow_ddl_) && (allow_introspection == allow_introspection_))
LOG_TRACE(trace_log, "List of all grants: " << result_ptr->toString() << (grant_option ? " WITH GRANT OPTION" : ""));
return result_ptr;
}
}
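The cache indexing can be read off the call sites in this file; the slot meanings below are an inference about intent, while the formula itself is quoted from above:

// index = (readonly_ != readonly) + (allow_ddl_ != allow_ddl) * 2
//       + (allow_introspection_ != allow_introspection) * 3 + grant_option * 4
//
// The checks above flip at most one setting at a time and pass grant_option only
// with the current settings, so the slots actually used are:
//   0: current settings            3: allow_introspection flipped
//   1: readonly flipped            4: current settings WITH GRANT OPTION
//   2: allow_ddl flipped
// which is why result_access_cache[] grows from 4 to 7 entries (5 and 6 are
// headroom for a flipped setting combined with grant_option).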

@@ -11,6 +11,8 @@ namespace Poco { class Logger; }
namespace DB
{
struct Settings;
struct User;
using UserPtr = std::shared_ptr<const User>;
class AccessRightsContext
@@ -19,7 +21,7 @@ public:
/// Default constructor creates an access rights context that allows everything.
AccessRightsContext();
AccessRightsContext(const ClientInfo & client_info_, const AccessRights & granted_to_user, const Settings & settings, const String & current_database_);
AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_);
/// Checks if the specified access is granted, and throws an exception if not.
/// Empty database means the current database.
@@ -52,21 +54,30 @@ public:
bool isGranted(Poco::Logger * log_, const AccessRightsElement & access) const;
bool isGranted(Poco::Logger * log_, const AccessRightsElements & access) const;
/// Checks if the specified access is granted with the grant option, and throws an exception if not.
void checkGrantOption(const AccessFlags & access) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const;
void checkGrantOption(const AccessRightsElement & access) const;
void checkGrantOption(const AccessRightsElements & access) const;
private:
template <int mode, typename... Args>
template <int mode, bool grant_option, typename... Args>
bool checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const;
template <int mode>
template <int mode, bool grant_option>
bool checkImpl(Poco::Logger * log_, const AccessRightsElement & access) const;
template <int mode>
template <int mode, bool grant_option>
bool checkImpl(Poco::Logger * log_, const AccessRightsElements & access) const;
boost::shared_ptr<const AccessRights> calculateResultAccess() const;
boost::shared_ptr<const AccessRights> calculateResultAccess(UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const;
boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option) const;
boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const;
const String user_name;
const AccessRights granted_to_user;
const UserPtr user;
const UInt64 readonly = 0;
const bool allow_ddl = true;
const bool allow_introspection = true;
@@ -74,7 +85,7 @@ private:
const ClientInfo::Interface interface = ClientInfo::Interface::TCP;
const ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN;
Poco::Logger * const trace_log = nullptr;
mutable boost::atomic_shared_ptr<const AccessRights> result_access_cache[4];
mutable boost::atomic_shared_ptr<const AccessRights> result_access_cache[7];
mutable std::mutex mutex;
};
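A hedged usage sketch of the new grant-option checks (obtaining the context object is hypothetical; the method names come from this header):

// Before letting a user delegate a privilege, a GRANT implementation can ask:
access_rights_context.checkGrantOption(AccessType::SELECT, "db", "table");
// Throws ACCESS_DENIED with a "WITH GRANT OPTION" message if the right is held
// only without the grant option; plain check() ignores the grant-option bit.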

@@ -1,15 +1,12 @@
#include <Access/AllowedClientHosts.h>
#include <Common/Exception.h>
#include <common/SimpleCache.h>
#include <Common/StringUtils/StringUtils.h>
#include <IO/ReadHelpers.h>
#include <Functions/likePatternToRegexp.h>
#include <Poco/Net/SocketAddress.h>
#include <Poco/RegularExpression.h>
#include <common/logger_useful.h>
#include <ext/scope_guard.h>
#include <boost/range/algorithm/find.hpp>
#include <boost/range/algorithm/find_first_of.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/algorithm/string/replace.hpp>
#include <ifaddrs.h>
@@ -27,20 +24,6 @@ namespace
using IPSubnet = AllowedClientHosts::IPSubnet;
const IPSubnet ALL_ADDRESSES{IPAddress{IPAddress::IPv6}, IPAddress{IPAddress::IPv6}};
const IPAddress & getIPV6Loopback()
{
static const IPAddress ip("::1");
return ip;
}
bool isIPV4LoopbackMappedToIPV6(const IPAddress & ip)
{
static const IPAddress prefix("::ffff:127.0.0.0");
/// 104 == 128 - 24, we have to reset the lowest 24 bits of 128 before comparing with `prefix`
/// (IPv4 loopback means any IP from 127.0.0.0 to 127.255.255.255).
return (ip & IPAddress(104, IPAddress::IPv6)) == prefix;
}
/// Converts an address to IPv6.
/// The loopback address "127.0.0.1" (or any "127.x.y.z") is converted to "::1".
IPAddress toIPv6(const IPAddress & ip)
@@ -52,35 +35,18 @@ namespace
v6 = IPAddress("::ffff:" + ip.toString());
// ::ffff:127.XX.XX.XX -> ::1
if (isIPV4LoopbackMappedToIPV6(v6))
v6 = getIPV6Loopback();
if ((v6 & IPAddress(104, IPAddress::IPv6)) == IPAddress("::ffff:127.0.0.0"))
v6 = IPAddress{"::1"};
return v6;
}
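A worked example of this conversion (values computed by hand, an illustration rather than quoted output):

// toIPv6(IPAddress("8.8.8.8"))   -> ::ffff:8.8.8.8   (IPv4-mapped, kept as-is)
// toIPv6(IPAddress("127.0.0.5")) -> ::ffff:127.0.0.5, which masked with /104
//                                   equals ::ffff:127.0.0.0, so it collapses to ::1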
/// Converts a subnet to IPv6.
IPSubnet toIPv6(const IPSubnet & subnet)
{
IPSubnet v6;
if (subnet.prefix.family() == IPAddress::IPv6)
v6.prefix = subnet.prefix;
else
v6.prefix = IPAddress("::ffff:" + subnet.prefix.toString());
if (subnet.mask.family() == IPAddress::IPv6)
v6.mask = subnet.mask;
else
v6.mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + subnet.mask.toString());
v6.prefix = v6.prefix & v6.mask;
// ::ffff:127.XX.XX.XX -> ::1
if (isIPV4LoopbackMappedToIPV6(v6.prefix))
v6 = {getIPV6Loopback(), IPAddress(128, IPAddress::IPv6)};
return v6;
return IPSubnet(toIPv6(subnet.getPrefix()), subnet.getMask());
}
/// Helper function for isAddressOfHost().
bool isAddressOfHostImpl(const IPAddress & address, const String & host)
{
@@ -150,7 +116,7 @@ namespace
int err = getifaddrs(&ifa_begin);
if (err)
return {getIPV6Loopback()};
return {IPAddress{"::1"}};
for (const ifaddrs * ifa = ifa_begin; ifa; ifa = ifa->ifa_next)
{
@@ -203,163 +169,203 @@ namespace
static SimpleCache<decltype(getHostByAddressImpl), &getHostByAddressImpl> cache;
return cache(address);
}
}
String AllowedClientHosts::IPSubnet::toString() const
{
unsigned int prefix_length = mask.prefixLength();
if (IPAddress{prefix_length, mask.family()} == mask)
return prefix.toString() + "/" + std::to_string(prefix_length);
return prefix.toString() + "/" + mask.toString();
}
AllowedClientHosts::AllowedClientHosts()
{
}
AllowedClientHosts::AllowedClientHosts(AllAddressesTag)
{
addAllAddresses();
}
AllowedClientHosts::~AllowedClientHosts() = default;
AllowedClientHosts::AllowedClientHosts(const AllowedClientHosts & src)
{
*this = src;
}
AllowedClientHosts & AllowedClientHosts::operator =(const AllowedClientHosts & src)
{
addresses = src.addresses;
localhost = src.localhost;
subnets = src.subnets;
host_names = src.host_names;
host_regexps = src.host_regexps;
compiled_host_regexps.clear();
return *this;
}
AllowedClientHosts::AllowedClientHosts(AllowedClientHosts && src) = default;
AllowedClientHosts & AllowedClientHosts::operator =(AllowedClientHosts && src) = default;
void AllowedClientHosts::clear()
{
addresses.clear();
localhost = false;
subnets.clear();
host_names.clear();
host_regexps.clear();
compiled_host_regexps.clear();
}
bool AllowedClientHosts::empty() const
{
return addresses.empty() && subnets.empty() && host_names.empty() && host_regexps.empty();
}
void AllowedClientHosts::addAddress(const IPAddress & address)
{
IPAddress addr_v6 = toIPv6(address);
if (boost::range::find(addresses, addr_v6) != addresses.end())
return;
addresses.push_back(addr_v6);
if (addr_v6.isLoopback())
localhost = true;
}
void AllowedClientHosts::addAddress(const String & address)
{
addAddress(IPAddress{address});
}
void AllowedClientHosts::addSubnet(const IPSubnet & subnet)
{
IPSubnet subnet_v6 = toIPv6(subnet);
if (subnet_v6.mask == IPAddress(128, IPAddress::IPv6))
void parseLikePatternIfIPSubnet(const String & pattern, IPSubnet & subnet, IPAddress::Family address_family)
{
addAddress(subnet_v6.prefix);
return;
size_t slash = pattern.find('/');
if (slash != String::npos)
{
/// IP subnet, e.g. "192.168.0.0/16" or "192.168.0.0/255.255.0.0".
subnet = IPSubnet{pattern};
return;
}
bool has_wildcard = (pattern.find_first_of("%_") != String::npos);
if (has_wildcard)
{
/// IP subnet specified with one of the wildcard characters, e.g. "192.168.%.%".
String wildcard_replaced_with_zero_bits = pattern;
String wildcard_replaced_with_one_bits = pattern;
if (address_family == IPAddress::IPv6)
{
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "_", "0");
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "%", "0000");
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "_", "f");
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "%", "ffff");
}
else if (address_family == IPAddress::IPv4)
{
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "%", "0");
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "%", "255");
}
IPAddress prefix{wildcard_replaced_with_zero_bits};
IPAddress mask = ~(prefix ^ IPAddress{wildcard_replaced_with_one_bits});
subnet = IPSubnet{prefix, mask};
return;
}
/// Exact IP address.
subnet = IPSubnet{pattern};
}
if (boost::range::find(subnets, subnet_v6) == subnets.end())
subnets.push_back(subnet_v6);
}
void AllowedClientHosts::addSubnet(const IPAddress & prefix, const IPAddress & mask)
{
addSubnet(IPSubnet{prefix, mask});
}
void AllowedClientHosts::addSubnet(const IPAddress & prefix, size_t num_prefix_bits)
{
addSubnet(prefix, IPAddress(num_prefix_bits, prefix.family()));
}
void AllowedClientHosts::addSubnet(const String & subnet)
{
size_t slash = subnet.find('/');
if (slash == String::npos)
/// Extracts a subnet, a host name or a host name regular expression from a LIKE pattern.
void parseLikePattern(
const String & pattern, std::optional<IPSubnet> & subnet, std::optional<String> & name, std::optional<String> & name_regexp)
{
addAddress(subnet);
return;
/// If `pattern` starts with digits followed by a dot, it's an IP pattern; otherwise it's a hostname pattern.
size_t first_not_digit = pattern.find_first_not_of("0123456789");
if ((first_not_digit != String::npos) && (first_not_digit != 0) && (pattern[first_not_digit] == '.'))
{
parseLikePatternIfIPSubnet(pattern, subnet.emplace(), IPAddress::IPv4);
return;
}
size_t first_not_hex = pattern.find_first_not_of("0123456789ABCDEFabcdef");
if (((first_not_hex == 4) && pattern[first_not_hex] == ':') || pattern.starts_with("::"))
{
parseLikePatternIfIPSubnet(pattern, subnet.emplace(), IPAddress::IPv6);
return;
}
bool has_wildcard = (pattern.find_first_of("%_") != String::npos);
if (has_wildcard)
{
name_regexp = likePatternToRegexp(pattern);
return;
}
name = pattern;
}
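A worked example of the wildcard branch above for the pattern "192.168.%.%" (hand-computed, IPv4 case):

// wildcards -> zero bits : "192.168.0.0"      (prefix candidate)
// wildcards -> one  bits : "192.168.255.255"
// mask = ~(prefix ^ ones) = ~(0.0.255.255)    = 255.255.0.0
// result: the subnet 192.168.0.0/255.255.0.0, i.e. 192.168.0.0/16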
IPAddress prefix{String{subnet, 0, slash}};
String mask(subnet, slash + 1, subnet.length() - slash - 1);
if (std::all_of(mask.begin(), mask.end(), isNumericASCII))
addSubnet(prefix, parseFromString<UInt8>(mask));
else
addSubnet(prefix, IPAddress{mask});
}
void AllowedClientHosts::addHostName(const String & host_name)
bool AllowedClientHosts::contains(const IPAddress & client_address) const
{
if (boost::range::find(host_names, host_name) != host_names.end())
return;
host_names.push_back(host_name);
if (boost::iequals(host_name, "localhost"))
localhost = true;
}
if (any_host)
return true;
IPAddress client_v6 = toIPv6(client_address);
void AllowedClientHosts::addHostRegexp(const String & host_regexp)
{
if (boost::range::find(host_regexps, host_regexp) == host_regexps.end())
host_regexps.push_back(host_regexp);
}
std::optional<bool> is_client_local_value;
auto is_client_local = [&]
{
if (is_client_local_value)
return *is_client_local_value;
is_client_local_value = isAddressOfLocalhost(client_v6);
return *is_client_local_value;
};
if (local_host && is_client_local())
return true;
void AllowedClientHosts::addAllAddresses()
{
clear();
addSubnet(ALL_ADDRESSES);
}
/// Check `addresses`.
auto check_address = [&](const IPAddress & address_)
{
IPAddress address_v6 = toIPv6(address_);
if (address_v6.isLoopback())
return is_client_local();
return address_v6 == client_v6;
};
for (const auto & address : addresses)
if (check_address(address))
return true;
bool AllowedClientHosts::containsAllAddresses() const
{
return (boost::range::find(subnets, ALL_ADDRESSES) != subnets.end())
|| (boost::range::find(host_regexps, ".*") != host_regexps.end())
|| (boost::range::find(host_regexps, "$") != host_regexps.end());
/// Check `subnets`.
auto check_subnet = [&](const IPSubnet & subnet_)
{
IPSubnet subnet_v6 = toIPv6(subnet_);
if (subnet_v6.isMaskAllBitsOne())
return check_address(subnet_v6.getPrefix());
return (client_v6 & subnet_v6.getMask()) == subnet_v6.getPrefix();
};
for (const auto & subnet : subnets)
if (check_subnet(subnet))
return true;
/// Check `names`.
auto check_name = [&](const String & name_)
{
if (boost::iequals(name_, "localhost"))
return is_client_local();
try
{
return isAddressOfHost(client_v6, name_);
}
catch (const Exception & e)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << client_address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
return false;
}
};
for (const String & name : names)
if (check_name(name))
return true;
/// Check `name_regexps`.
std::optional<String> resolved_host;
auto check_name_regexp = [&](const String & name_regexp_)
{
try
{
if (boost::iequals(name_regexp_, "localhost"))
return is_client_local();
if (!resolved_host)
resolved_host = getHostByAddress(client_v6);
if (resolved_host->empty())
return false;
Poco::RegularExpression re(name_regexp_);
Poco::RegularExpression::Match match;
return re.match(*resolved_host, match) != 0;
}
catch (const Exception & e)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << client_address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
return false;
}
};
for (const String & name_regexp : name_regexps)
if (check_name_regexp(name_regexp))
return true;
auto check_like_pattern = [&](const String & pattern)
{
std::optional<IPSubnet> subnet;
std::optional<String> name;
std::optional<String> name_regexp;
parseLikePattern(pattern, subnet, name, name_regexp);
if (subnet)
return check_subnet(*subnet);
else if (name)
return check_name(*name);
else if (name_regexp)
return check_name_regexp(*name_regexp);
else
return false;
};
for (const String & like_pattern : like_patterns)
if (check_like_pattern(like_pattern))
return true;
return false;
}
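A hedged usage sketch of the reworked matching (every method name appears in this diff; the concrete values are illustrative):

AllowedClientHosts hosts;
hosts.addSubnet("10.0.0.0/8");
hosts.addName("client01.example.com");   // exact host name, resolved at check time
hosts.addLikePattern("192.168.%.%");     // parsed into a subnet by parseLikePattern()
bool allowed = hosts.contains(Poco::Net::IPAddress("10.1.2.3"));  // true: subnet match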
@@ -374,86 +380,4 @@ void AllowedClientHosts::checkContains(const IPAddress & address, const String &
}
}
bool AllowedClientHosts::contains(const IPAddress & address) const
{
/// Check `ip_addresses`.
IPAddress addr_v6 = toIPv6(address);
if (boost::range::find(addresses, addr_v6) != addresses.end())
return true;
if (localhost && isAddressOfLocalhost(addr_v6))
return true;
/// Check `ip_subnets`.
for (const auto & subnet : subnets)
if ((addr_v6 & subnet.mask) == subnet.prefix)
return true;
/// Check `hosts`.
for (const String & host_name : host_names)
{
try
{
if (isAddressOfHost(addr_v6, host_name))
return true;
}
catch (const Exception & e)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
}
}
/// Check `host_regexps`.
try
{
String resolved_host = getHostByAddress(addr_v6);
if (!resolved_host.empty())
{
compileRegexps();
for (const auto & compiled_regexp : compiled_host_regexps)
{
Poco::RegularExpression::Match match;
if (compiled_regexp && compiled_regexp->match(resolved_host, match))
return true;
}
}
}
catch (const Exception & e)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
}
return false;
}
void AllowedClientHosts::compileRegexps() const
{
if (compiled_host_regexps.size() == host_regexps.size())
return;
size_t old_size = compiled_host_regexps.size();
compiled_host_regexps.reserve(host_regexps.size());
for (size_t i = old_size; i != host_regexps.size(); ++i)
compiled_host_regexps.emplace_back(std::make_unique<Poco::RegularExpression>(host_regexps[i]));
}
bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs)
{
return (lhs.addresses == rhs.addresses) && (lhs.subnets == rhs.subnets) && (lhs.host_names == rhs.host_names)
&& (lhs.host_regexps == rhs.host_regexps);
}
}

@@ -4,12 +4,9 @@
#include <Poco/Net/IPAddress.h>
#include <memory>
#include <vector>
namespace Poco
{
class RegularExpression;
}
#include <boost/range/algorithm/find.hpp>
#include <boost/range/algorithm_ext/erase.hpp>
#include <boost/algorithm/string/predicate.hpp>
namespace DB
@@ -20,69 +17,100 @@ class AllowedClientHosts
public:
using IPAddress = Poco::Net::IPAddress;
struct IPSubnet
class IPSubnet
{
IPAddress prefix;
IPAddress mask;
public:
IPSubnet() {}
IPSubnet(const IPAddress & prefix_, const IPAddress & mask_) { set(prefix_, mask_); }
IPSubnet(const IPAddress & prefix_, size_t num_prefix_bits) { set(prefix_, num_prefix_bits); }
explicit IPSubnet(const IPAddress & address) { set(address); }
explicit IPSubnet(const String & str);
const IPAddress & getPrefix() const { return prefix; }
const IPAddress & getMask() const { return mask; }
bool isMaskAllBitsOne() const;
String toString() const;
friend bool operator ==(const IPSubnet & lhs, const IPSubnet & rhs) { return (lhs.prefix == rhs.prefix) && (lhs.mask == rhs.mask); }
friend bool operator !=(const IPSubnet & lhs, const IPSubnet & rhs) { return !(lhs == rhs); }
private:
void set(const IPAddress & prefix_, const IPAddress & mask_);
void set(const IPAddress & prefix_, size_t num_prefix_bits);
void set(const IPAddress & address);
IPAddress prefix;
IPAddress mask;
};
struct AllAddressesTag {};
struct AnyHostTag {};
AllowedClientHosts();
explicit AllowedClientHosts(AllAddressesTag);
~AllowedClientHosts();
AllowedClientHosts() {}
explicit AllowedClientHosts(AnyHostTag) { addAnyHost(); }
~AllowedClientHosts() {}
AllowedClientHosts(const AllowedClientHosts & src);
AllowedClientHosts & operator =(const AllowedClientHosts & src);
AllowedClientHosts(AllowedClientHosts && src);
AllowedClientHosts & operator =(AllowedClientHosts && src);
AllowedClientHosts(const AllowedClientHosts & src) = default;
AllowedClientHosts & operator =(const AllowedClientHosts & src) = default;
AllowedClientHosts(AllowedClientHosts && src) = default;
AllowedClientHosts & operator =(AllowedClientHosts && src) = default;
/// Removes all contained addresses. This will disallow all addresses.
/// Removes all contained addresses. This will disallow all hosts.
void clear();
bool empty() const;
/// Allows exact IP address.
/// For example, 213.180.204.3 or 2a02:6b8::3
void addAddress(const IPAddress & address);
void addAddress(const String & address);
/// Allows an IP subnet.
void addSubnet(const IPSubnet & subnet);
void addSubnet(const String & subnet);
/// Allows an IP subnet.
/// For example, 212.234.1.1/255.255.255.0 or 2a02:6b8::3/FFFF:FFFF:FFFF:FFFF::
void addSubnet(const IPAddress & prefix, const IPAddress & mask);
/// Allows an IP subnet.
/// For example, 10.0.0.1/8 or 2a02:6b8::3/64
void addSubnet(const IPAddress & prefix, size_t num_prefix_bits);
/// Allows all addresses.
void addAllAddresses();
/// Allows an exact host. The `contains()` function will check that the provided address equals one of that host's addresses.
void addHostName(const String & host_name);
/// Allows a regular expression for the host.
void addHostRegexp(const String & host_regexp);
void addAddress(const String & address) { addAddress(IPAddress(address)); }
void removeAddress(const IPAddress & address);
void removeAddress(const String & address) { removeAddress(IPAddress{address}); }
const std::vector<IPAddress> & getAddresses() const { return addresses; }
/// Allows an IP subnet.
/// For example, 212.234.1.1/255.255.255.0 or 2a02:6b8::3/64
void addSubnet(const IPSubnet & subnet);
void addSubnet(const String & subnet) { addSubnet(IPSubnet{subnet}); }
void addSubnet(const IPAddress & prefix, const IPAddress & mask) { addSubnet({prefix, mask}); }
void addSubnet(const IPAddress & prefix, size_t num_prefix_bits) { addSubnet({prefix, num_prefix_bits}); }
void removeSubnet(const IPSubnet & subnet);
void removeSubnet(const String & subnet) { removeSubnet(IPSubnet{subnet}); }
void removeSubnet(const IPAddress & prefix, const IPAddress & mask) { removeSubnet({prefix, mask}); }
void removeSubnet(const IPAddress & prefix, size_t num_prefix_bits) { removeSubnet({prefix, num_prefix_bits}); }
const std::vector<IPSubnet> & getSubnets() const { return subnets; }
const std::vector<String> & getHostNames() const { return host_names; }
const std::vector<String> & getHostRegexps() const { return host_regexps; }
/// Allows an exact host name. The `contains()` function will check that the provided address equals one of that host's addresses.
void addName(const String & name);
void removeName(const String & name);
const std::vector<String> & getNames() const { return names; }
/// Allows the host names matching a regular expression.
void addNameRegexp(const String & name_regexp);
void removeNameRegexp(const String & name_regexp);
const std::vector<String> & getNameRegexps() const { return name_regexps; }
/// Allows IP addresses or host names using LIKE pattern.
/// This pattern can contain % and _ wildcard characters.
/// For example, addLikePattern("@") will allow all addresses.
void addLikePattern(const String & pattern);
void removeLikePattern(const String & like_pattern);
const std::vector<String> & getLikePatterns() const { return like_patterns; }
/// Allows local host.
void addLocalHost();
void removeLocalHost();
bool containsLocalHost() const { return local_host; }
/// Allows any host.
void addAnyHost();
bool containsAnyHost() const { return any_host; }
void add(const AllowedClientHosts & other);
void remove(const AllowedClientHosts & other);
/// Checks if the provided address is in the list. Returns false if not.
bool contains(const IPAddress & address) const;
/// Checks if any address is allowed.
bool containsAllAddresses() const;
/// Checks if the provided address is in the list. Throws an exception if not.
/// `username` is only used for generating an error message if the address isn't in the list.
void checkContains(const IPAddress & address, const String & user_name = String()) const;
@ -91,13 +119,269 @@ public:
friend bool operator !=(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs) { return !(lhs == rhs); }
private:
void compileRegexps() const;
std::vector<IPAddress> addresses;
bool localhost = false;
std::vector<IPSubnet> subnets;
std::vector<String> host_names;
std::vector<String> host_regexps;
mutable std::vector<std::unique_ptr<Poco::RegularExpression>> compiled_host_regexps;
Strings names;
Strings name_regexps;
Strings like_patterns;
bool any_host = false;
bool local_host = false;
};
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & prefix_, const IPAddress & mask_)
{
prefix = prefix_;
mask = mask_;
if (prefix.family() != mask.family())
{
if (prefix.family() == IPAddress::IPv4)
prefix = IPAddress("::ffff:" + prefix.toString());
if (mask.family() == IPAddress::IPv4)
mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + mask.toString());
}
prefix = prefix & mask;
if (prefix.family() == IPAddress::IPv4)
{
if ((prefix & IPAddress{8, IPAddress::IPv4}) == IPAddress{"127.0.0.0"})
{
// 127.XX.XX.XX -> 127.0.0.1
prefix = IPAddress{"127.0.0.1"};
mask = IPAddress{32, IPAddress::IPv4};
}
}
else
{
if ((prefix & IPAddress{104, IPAddress::IPv6}) == IPAddress{"::ffff:127.0.0.0"})
{
// ::ffff:127.XX.XX.XX -> ::1
prefix = IPAddress{"::1"};
mask = IPAddress{128, IPAddress::IPv6};
}
}
}
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & prefix_, size_t num_prefix_bits)
{
set(prefix_, IPAddress(num_prefix_bits, prefix_.family()));
}
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & address)
{
set(address, address.length() * 8);
}
inline AllowedClientHosts::IPSubnet::IPSubnet(const String & str)
{
size_t slash = str.find('/');
if (slash == String::npos)
{
set(IPAddress(str));
return;
}
IPAddress new_prefix{String{str, 0, slash}};
String mask_str(str, slash + 1, str.length() - slash - 1);
bool only_digits = (mask_str.find_first_not_of("0123456789") == std::string::npos);
if (only_digits)
set(new_prefix, std::stoul(mask_str));
else
set(new_prefix, IPAddress{mask_str});
}
inline String AllowedClientHosts::IPSubnet::toString() const
{
unsigned int prefix_length = mask.prefixLength();
if (isMaskAllBitsOne())
return prefix.toString();
else if (IPAddress{prefix_length, mask.family()} == mask)
return prefix.toString() + "/" + std::to_string(prefix_length);
else
return prefix.toString() + "/" + mask.toString();
}
inline bool AllowedClientHosts::IPSubnet::isMaskAllBitsOne() const
{
return mask == IPAddress(mask.length() * 8, mask.family());
}
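A minimal usage sketch of the parsing and normalization above (the main() wrapper is illustrative; it assumes this header is reachable as <Access/AllowedClientHosts.h>):

#include <Access/AllowedClientHosts.h>
#include <iostream>

int main()
{
    using IPSubnet = DB::AllowedClientHosts::IPSubnet;

    IPSubnet cidr{"10.0.0.1/8"};            /// mask is built from the prefix length, prefix is ANDed with it
    IPSubnet masked{"10.0.0.1/255.0.0.0"};  /// same subnet, mask given explicitly
    IPSubnet single{"2a02:6b8::3"};         /// no '/': the mask is all bits one

    std::cout << cidr.toString() << '\n';           /// "10.0.0.0/8"
    std::cout << masked.toString() << '\n';         /// "10.0.0.0/8" too, since the mask has a prefix form
    std::cout << single.isMaskAllBitsOne() << '\n'; /// 1
}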
inline void AllowedClientHosts::clear()
{
addresses = {};
subnets = {};
names = {};
name_regexps = {};
like_patterns = {};
any_host = false;
local_host = false;
}
inline bool AllowedClientHosts::empty() const
{
return !any_host && !local_host && addresses.empty() && subnets.empty() && names.empty() && name_regexps.empty() && like_patterns.empty();
}
inline void AllowedClientHosts::addAddress(const IPAddress & address)
{
if (address.isLoopback())
local_host = true;
else if (boost::range::find(addresses, address) == addresses.end())
addresses.push_back(address);
}
inline void AllowedClientHosts::removeAddress(const IPAddress & address)
{
if (address.isLoopback())
local_host = false;
else
boost::range::remove_erase(addresses, address);
}
inline void AllowedClientHosts::addSubnet(const IPSubnet & subnet)
{
if (subnet.getMask().isWildcard())
any_host = true;
else if (subnet.isMaskAllBitsOne())
addAddress(subnet.getPrefix());
else if (boost::range::find(subnets, subnet) == subnets.end())
subnets.push_back(subnet);
}
inline void AllowedClientHosts::removeSubnet(const IPSubnet & subnet)
{
if (subnet.getMask().isWildcard())
any_host = false;
else if (subnet.isMaskAllBitsOne())
removeAddress(subnet.getPrefix());
else
boost::range::remove_erase(subnets, subnet);
}
inline void AllowedClientHosts::addName(const String & name)
{
if (boost::iequals(name, "localhost"))
local_host = true;
else if (boost::range::find(names, name) == names.end())
names.push_back(name);
}
inline void AllowedClientHosts::removeName(const String & name)
{
if (boost::iequals(name, "localhost"))
local_host = false;
else
boost::range::remove_erase(names, name);
}
inline void AllowedClientHosts::addNameRegexp(const String & name_regexp)
{
if (boost::iequals(name_regexp, "localhost"))
local_host = true;
else if (name_regexp == ".*")
any_host = true;
else if (boost::range::find(name_regexps, name_regexp) == name_regexps.end())
name_regexps.push_back(name_regexp);
}
inline void AllowedClientHosts::removeNameRegexp(const String & name_regexp)
{
if (boost::iequals(name_regexp, "localhost"))
local_host = false;
else if (name_regexp == ".*")
any_host = false;
else
boost::range::remove_erase(name_regexps, name_regexp);
}
inline void AllowedClientHosts::addLikePattern(const String & pattern)
{
if (boost::iequals(pattern, "localhost") || (pattern == "127.0.0.1") || (pattern == "::1"))
local_host = true;
else if ((pattern == "@") || (pattern == "0.0.0.0/0") || (pattern == "::/0"))
any_host = true;
else if (boost::range::find(like_patterns, pattern) == like_patterns.end())
like_patterns.push_back(pattern);
}
inline void AllowedClientHosts::removeLikePattern(const String & pattern)
{
if (boost::iequals(pattern, "localhost") || (pattern == "127.0.0.1") || (pattern == "::1"))
local_host = false;
else if ((pattern == "@") || (pattern == "0.0.0.0/0") || (pattern == "::/0"))
any_host = false;
else
boost::range::remove_erase(like_patterns, pattern);
}
inline void AllowedClientHosts::addLocalHost()
{
local_host = true;
}
inline void AllowedClientHosts::removeLocalHost()
{
local_host = false;
}
inline void AllowedClientHosts::addAnyHost()
{
clear();
any_host = true;
}
inline void AllowedClientHosts::add(const AllowedClientHosts & other)
{
if (other.containsAnyHost())
{
addAnyHost();
return;
}
if (other.containsLocalHost())
addLocalHost();
for (const IPAddress & address : other.getAddresses())
addAddress(address);
for (const IPSubnet & subnet : other.getSubnets())
addSubnet(subnet);
for (const String & name : other.getNames())
addName(name);
for (const String & name_regexp : other.getNameRegexps())
addNameRegexp(name_regexp);
for (const String & like_pattern : other.getLikePatterns())
addLikePattern(like_pattern);
}
inline void AllowedClientHosts::remove(const AllowedClientHosts & other)
{
if (other.containsAnyHost())
{
clear();
return;
}
if (other.containsLocalHost())
removeLocalHost();
for (const IPAddress & address : other.getAddresses())
removeAddress(address);
for (const IPSubnet & subnet : other.getSubnets())
removeSubnet(subnet);
for (const String & name : other.getNames())
removeName(name);
for (const String & name_regexp : other.getNameRegexps())
removeNameRegexp(name_regexp);
for (const String & like_pattern : other.getLikePatterns())
removeLikePattern(like_pattern);
}
inline bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs)
{
return (lhs.any_host == rhs.any_host) && (lhs.local_host == rhs.local_host) && (lhs.addresses == rhs.addresses)
&& (lhs.subnets == rhs.subnets) && (lhs.names == rhs.names) && (lhs.name_regexps == rhs.name_regexps)
&& (lhs.like_patterns == rhs.like_patterns);
}
}
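How the flags and lists interact, as a short sketch (the wrapper is illustrative; the assertions follow from the inline definitions above):

#include <Access/AllowedClientHosts.h>
#include <cassert>

int main()
{
    DB::AllowedClientHosts hosts;       /// default-constructed: empty(), nothing is allowed
    hosts.addSubnet("192.168.0.0/16");  /// stored in `subnets`
    hosts.addName("localhost");         /// folded into the `local_host` flag, not into `names`
    hosts.addLikePattern("@");          /// "@" means any host: sets the `any_host` flag

    assert(hosts.containsAnyHost());
    hosts.removeLikePattern("@");       /// clears `any_host` again; the subnet stays
    assert(hosts.containsLocalHost());
    assert(hosts.getSubnets().size() == 1);
}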

View File

@ -1,166 +1,18 @@
#include <Access/Authentication.h>
#include <Common/Exception.h>
#include <common/StringRef.h>
#include <Core/Defines.h>
#include <Poco/SHA1Engine.h>
#include <boost/algorithm/hex.hpp>
#include "config_core.h"
#if USE_SSL
# include <openssl/sha.h>
#endif
namespace DB
{
namespace ErrorCodes
{
extern const int SUPPORT_IS_DISABLED;
extern const int REQUIRED_PASSWORD;
extern const int WRONG_PASSWORD;
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
}
namespace
{
using Digest = Authentication::Digest;
Digest encodePlainText(const StringRef & text)
{
return Digest(text.data, text.data + text.size);
}
Digest encodeSHA256(const StringRef & text)
{
#if USE_SSL
Digest hash;
hash.resize(32);
SHA256_CTX ctx;
SHA256_Init(&ctx);
SHA256_Update(&ctx, reinterpret_cast<const UInt8 *>(text.data), text.size);
SHA256_Final(hash.data(), &ctx);
return hash;
#else
UNUSED(text);
throw DB::Exception("SHA256 passwords support is disabled, because ClickHouse was built without SSL library", DB::ErrorCodes::SUPPORT_IS_DISABLED);
#endif
}
Digest encodeSHA1(const StringRef & text)
{
Poco::SHA1Engine engine;
engine.update(text.data, text.size);
return engine.digest();
}
Digest encodeSHA1(const Digest & text)
{
return encodeSHA1(StringRef{reinterpret_cast<const char *>(text.data()), text.size()});
}
Digest encodeDoubleSHA1(const StringRef & text)
{
return encodeSHA1(encodeSHA1(text));
}
}
Authentication::Authentication(Authentication::Type type_)
: type(type_)
{
}
void Authentication::setPassword(const String & password_)
{
switch (type)
{
case NO_PASSWORD:
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
case PLAINTEXT_PASSWORD:
setPasswordHashBinary(encodePlainText(password_));
return;
case SHA256_PASSWORD:
setPasswordHashBinary(encodeSHA256(password_));
return;
case DOUBLE_SHA1_PASSWORD:
setPasswordHashBinary(encodeDoubleSHA1(password_));
return;
}
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
}
String Authentication::getPassword() const
{
if (type != PLAINTEXT_PASSWORD)
throw Exception("Cannot decode the password", ErrorCodes::LOGICAL_ERROR);
return String(password_hash.data(), password_hash.data() + password_hash.size());
}
void Authentication::setPasswordHashHex(const String & hash)
{
Digest digest;
digest.resize(hash.size() / 2);
boost::algorithm::unhex(hash.begin(), hash.end(), digest.data());
setPasswordHashBinary(digest);
}
String Authentication::getPasswordHashHex() const
{
String hex;
hex.resize(password_hash.size() * 2);
boost::algorithm::hex(password_hash.begin(), password_hash.end(), hex.data());
return hex;
}
void Authentication::setPasswordHashBinary(const Digest & hash)
{
switch (type)
{
case NO_PASSWORD:
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
case PLAINTEXT_PASSWORD:
{
password_hash = hash;
return;
}
case SHA256_PASSWORD:
{
if (hash.size() != 32)
throw Exception(
"Password hash for the 'SHA256_PASSWORD' authentication type has length " + std::to_string(hash.size())
+ " but must be exactly 32 bytes.",
ErrorCodes::BAD_ARGUMENTS);
password_hash = hash;
return;
}
case DOUBLE_SHA1_PASSWORD:
{
if (hash.size() != 20)
throw Exception(
"Password hash for the 'DOUBLE_SHA1_PASSWORD' authentication type has length " + std::to_string(hash.size())
+ " but must be exactly 20 bytes.",
ErrorCodes::BAD_ARGUMENTS);
password_hash = hash;
return;
}
}
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
}
Digest Authentication::getPasswordDoubleSHA1() const
Authentication::Digest Authentication::getPasswordDoubleSHA1() const
{
switch (type)
{
@ -198,12 +50,12 @@ bool Authentication::isCorrectPassword(const String & password_) const
case PLAINTEXT_PASSWORD:
{
if (password_ == StringRef{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()})
if (password_ == std::string_view{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()})
return true;
// For compatibility with MySQL clients which support only native authentication plugin, SHA1 can be passed instead of password.
auto password_sha1 = encodeSHA1(password_hash);
return password_ == StringRef{reinterpret_cast<const char *>(password_sha1.data()), password_sha1.size()};
return password_ == std::string_view{reinterpret_cast<const char *>(password_sha1.data()), password_sha1.size()};
}
case SHA256_PASSWORD:
@ -234,10 +86,5 @@ void Authentication::checkPassword(const String & password_, const String & user
throw Exception("Wrong password" + info_about_user_name(), ErrorCodes::WRONG_PASSWORD);
}
bool operator ==(const Authentication & lhs, const Authentication & rhs)
{
return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash);
}
}

View File

@ -1,10 +1,22 @@
#pragma once
#include <Core/Types.h>
#include <Common/Exception.h>
#include <Common/OpenSSLHelpers.h>
#include <Poco/SHA1Engine.h>
#include <boost/algorithm/hex.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int SUPPORT_IS_DISABLED;
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
}
/// Authentication type and encrypted password for checking when a user logs in.
class Authentication
{
@ -27,7 +39,7 @@ public:
using Digest = std::vector<UInt8>;
Authentication(Authentication::Type type = NO_PASSWORD);
Authentication(Authentication::Type type_ = NO_PASSWORD) : type(type_) {}
Authentication(const Authentication & src) = default;
Authentication & operator =(const Authentication & src) = default;
Authentication(Authentication && src) = default;
@ -36,17 +48,19 @@ public:
Type getType() const { return type; }
/// Sets the password and encrypt it using the authentication type set in the constructor.
void setPassword(const String & password);
void setPassword(const String & password_);
/// Returns the password. Allowed to use only for Type::PLAINTEXT_PASSWORD.
String getPassword() const;
/// Sets the password as a string of hexadecimal digits.
void setPasswordHashHex(const String & hash);
String getPasswordHashHex() const;
/// Sets the password in binary form.
void setPasswordHashBinary(const Digest & hash);
const Digest & getPasswordHashBinary() const { return password_hash; }
/// Returns SHA1(SHA1(password)) used by MySQL compatibility server for authentication.
@ -60,11 +74,124 @@ public:
/// `user_name` is only used for generating an error message if the password is incorrect.
void checkPassword(const String & password, const String & user_name = String()) const;
friend bool operator ==(const Authentication & lhs, const Authentication & rhs);
friend bool operator ==(const Authentication & lhs, const Authentication & rhs) { return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash); }
friend bool operator !=(const Authentication & lhs, const Authentication & rhs) { return !(lhs == rhs); }
private:
static Digest encodePlainText(const std::string_view & text) { return Digest(text.data(), text.data() + text.size()); }
static Digest encodeSHA256(const std::string_view & text);
static Digest encodeSHA1(const std::string_view & text);
static Digest encodeSHA1(const Digest & text) { return encodeSHA1(std::string_view{reinterpret_cast<const char *>(text.data()), text.size()}); }
static Digest encodeDoubleSHA1(const std::string_view & text) { return encodeSHA1(encodeSHA1(text)); }
Type type = Type::NO_PASSWORD;
Digest password_hash;
};
inline Authentication::Digest Authentication::encodeSHA256(const std::string_view & text [[maybe_unused]])
{
#if USE_SSL
Digest hash;
hash.resize(32);
::DB::encodeSHA256(text, hash.data());
return hash;
#else
throw DB::Exception(
"SHA256 passwords support is disabled, because ClickHouse was built without SSL library",
DB::ErrorCodes::SUPPORT_IS_DISABLED);
#endif
}
inline Authentication::Digest Authentication::encodeSHA1(const std::string_view & text)
{
Poco::SHA1Engine engine;
engine.update(text.data(), text.size());
return engine.digest();
}
inline void Authentication::setPassword(const String & password_)
{
switch (type)
{
case NO_PASSWORD:
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
case PLAINTEXT_PASSWORD:
return setPasswordHashBinary(encodePlainText(password_));
case SHA256_PASSWORD:
return setPasswordHashBinary(encodeSHA256(password_));
case DOUBLE_SHA1_PASSWORD:
return setPasswordHashBinary(encodeDoubleSHA1(password_));
}
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
}
inline String Authentication::getPassword() const
{
if (type != PLAINTEXT_PASSWORD)
throw Exception("Cannot decode the password", ErrorCodes::LOGICAL_ERROR);
return String(password_hash.data(), password_hash.data() + password_hash.size());
}
inline void Authentication::setPasswordHashHex(const String & hash)
{
Digest digest;
digest.resize(hash.size() / 2);
boost::algorithm::unhex(hash.begin(), hash.end(), digest.data());
setPasswordHashBinary(digest);
}
inline String Authentication::getPasswordHashHex() const
{
String hex;
hex.resize(password_hash.size() * 2);
boost::algorithm::hex(password_hash.begin(), password_hash.end(), hex.data());
return hex;
}
inline void Authentication::setPasswordHashBinary(const Digest & hash)
{
switch (type)
{
case NO_PASSWORD:
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
case PLAINTEXT_PASSWORD:
{
password_hash = hash;
return;
}
case SHA256_PASSWORD:
{
if (hash.size() != 32)
throw Exception(
"Password hash for the 'SHA256_PASSWORD' authentication type has length " + std::to_string(hash.size())
+ " but must be exactly 32 bytes.",
ErrorCodes::BAD_ARGUMENTS);
password_hash = hash;
return;
}
case DOUBLE_SHA1_PASSWORD:
{
if (hash.size() != 20)
throw Exception(
"Password hash for the 'DOUBLE_SHA1_PASSWORD' authentication type has length " + std::to_string(hash.size())
+ " but must be exactly 20 bytes.",
ErrorCodes::BAD_ARGUMENTS);
password_hash = hash;
return;
}
}
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
}
}
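A usage sketch for the class above (assumes a build with USE_SSL for the SHA256 variant; the wrapper function is illustrative):

#include <Access/Authentication.h>
#include <cassert>

void authenticationExample()
{
    DB::Authentication auth{DB::Authentication::SHA256_PASSWORD};
    auth.setPassword("secret");                /// stored as SHA256("secret"); the plaintext is not kept

    auth.checkPassword("secret");              /// passes
    /// auth.checkPassword("wrong", "alice");  /// would throw WRONG_PASSWORD, mentioning "alice"

    /// The digest round-trips through its hex form:
    DB::Authentication copy{DB::Authentication::SHA256_PASSWORD};
    copy.setPasswordHashHex(auth.getPasswordHashHex());
    assert(copy == auth);
}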

View File

@ -10,7 +10,8 @@ bool User::equal(const IAccessEntity & other) const
return false;
const auto & other_user = typeid_cast<const User &>(other);
return (authentication == other_user.authentication) && (allowed_client_hosts == other_user.allowed_client_hosts)
&& (access == other_user.access) && (profile == other_user.profile);
&& (access == other_user.access) && (access_with_grant_option == other_user.access_with_grant_option)
&& (profile == other_user.profile);
}
}

View File

@ -14,8 +14,9 @@ namespace DB
struct User : public IAccessEntity
{
Authentication authentication;
AllowedClientHosts allowed_client_hosts;
AllowedClientHosts allowed_client_hosts{AllowedClientHosts::AnyHostTag{}};
AccessRights access;
AccessRights access_with_grant_option;
String profile;
bool equal(const IAccessEntity & other) const override;

View File

@ -90,15 +90,16 @@ namespace
{
Poco::Util::AbstractConfiguration::Keys keys;
config.keys(networks_config, keys);
user->allowed_client_hosts.clear();
for (const String & key : keys)
{
String value = config.getString(networks_config + "." + key);
if (key.starts_with("ip"))
user->allowed_client_hosts.addSubnet(value);
else if (key.starts_with("host_regexp"))
user->allowed_client_hosts.addHostRegexp(value);
user->allowed_client_hosts.addNameRegexp(value);
else if (key.starts_with("host"))
user->allowed_client_hosts.addHostName(value);
user->allowed_client_hosts.addName(value);
else
throw Exception("Unknown address pattern type: " + key, ErrorCodes::UNKNOWN_ADDRESS_PATTERN_TYPE);
}
@ -143,7 +144,6 @@ namespace
user->access.fullRevoke(AccessFlags::databaseLevel());
for (const String & database : *databases)
user->access.grant(AccessFlags::databaseLevel(), database);
user->access.grant(AccessFlags::databaseLevel(), "system"); /// Anyone has access to the "system" database.
}
if (dictionaries)
@ -155,6 +155,8 @@ namespace
else if (databases)
user->access.grant(AccessType::dictGet, IDictionary::NO_DATABASE_TAG);
user->access_with_grant_option = user->access;
return user;
}

View File

@ -74,7 +74,8 @@ void Connection::connect(const ConnectionTimeouts & timeouts)
current_resolved_address = DNSResolver::instance().resolveAddress(host, port);
socket->connect(*current_resolved_address, timeouts.connection_timeout);
const auto & connection_timeout = static_cast<bool>(secure) ? timeouts.secure_connection_timeout : timeouts.connection_timeout;
socket->connect(*current_resolved_address, connection_timeout);
socket->setReceiveTimeout(timeouts.receive_timeout);
socket->setSendTimeout(timeouts.send_timeout);
socket->setNoDelay(true);

View File

@ -15,7 +15,8 @@
M(DiskSpaceReservedForMerge, "Disk space reserved for currently running background merges. It is slightly more than the total size of currently merging parts.") \
M(DistributedSend, "Number of connections to remote servers sending data that was INSERTed into Distributed tables. Both synchronous and asynchronous mode.") \
M(QueryPreempted, "Number of queries that are stopped and waiting due to 'priority' setting.") \
M(TCPConnection, "Number of connections to TCP server (clients with native interface)") \
M(TCPConnection, "Number of connections to TCP server (clients with native interface), also included server-server distributed query connections") \
M(MySQLConnection, "Number of client connections using MySQL protocol") \
M(HTTPConnection, "Number of connections to HTTP server") \
M(InterserverConnection, "Number of connections from other replicas to fetch parts") \
M(OpenFileForRead, "Number of files open for reading") \

View File

@ -481,6 +481,7 @@ namespace ErrorCodes
extern const int UNABLE_TO_SKIP_UNUSED_SHARDS = 507;
extern const int UNKNOWN_ACCESS_TYPE = 508;
extern const int INVALID_GRANT = 509;
extern const int CACHE_DICTIONARY_UPDATE_FAIL = 510;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@ -1,12 +1,17 @@
#include <cstdlib>
#include "MemoryTracker.h"
#include <common/likely.h>
#include <common/logger_useful.h>
#include <IO/WriteHelpers.h>
#include "Common/TraceCollector.h"
#include <Common/CurrentThread.h>
#include <Common/Exception.h>
#include <Common/formatReadable.h>
#include <Common/CurrentThread.h>
#include <IO/WriteHelpers.h>
#include <common/likely.h>
#include <common/logger_useful.h>
#include <ext/singleton.h>
#include <atomic>
#include <cmath>
#include <cstdlib>
namespace DB
@ -73,7 +78,7 @@ void MemoryTracker::alloc(Int64 size)
return;
/** Using memory_order_relaxed means that if allocations are done simultaneously,
* we allow the 'memory limit exceeded' exception to be thrown only on the next allocation.
* So, we allow over-allocations.
*/
Int64 will_be = size + amount.fetch_add(size, std::memory_order_relaxed);
@ -81,7 +86,8 @@ void MemoryTracker::alloc(Int64 size)
if (metric != CurrentMetrics::end())
CurrentMetrics::add(metric, size);
Int64 current_limit = limit.load(std::memory_order_relaxed);
Int64 current_hard_limit = hard_limit.load(std::memory_order_relaxed);
Int64 current_profiler_limit = profiler_limit.load(std::memory_order_relaxed);
/// Using non-thread-safe random number generator. Joint distribution in different threads would not be uniform.
/// In this case, it doesn't matter.
@ -98,12 +104,19 @@ void MemoryTracker::alloc(Int64 size)
message << " " << description;
message << ": fault injected. Would use " << formatReadableSizeWithBinarySuffix(will_be)
<< " (attempt to allocate chunk of " << size << " bytes)"
<< ", maximum: " << formatReadableSizeWithBinarySuffix(current_limit);
<< ", maximum: " << formatReadableSizeWithBinarySuffix(current_hard_limit);
throw DB::Exception(message.str(), DB::ErrorCodes::MEMORY_LIMIT_EXCEEDED);
}
if (unlikely(current_limit && will_be > current_limit))
if (unlikely(current_profiler_limit && will_be > current_profiler_limit))
{
auto no_track = blocker.cancel();
ext::Singleton<DB::TraceCollector>()->collect(size);
setOrRaiseProfilerLimit(current_profiler_limit + Int64(std::ceil((will_be - current_profiler_limit) / profiler_step)) * profiler_step);
}
if (unlikely(current_hard_limit && will_be > current_hard_limit))
{
free(size);
@ -116,7 +129,7 @@ void MemoryTracker::alloc(Int64 size)
message << " " << description;
message << " exceeded: would use " << formatReadableSizeWithBinarySuffix(will_be)
<< " (attempt to allocate chunk of " << size << " bytes)"
<< ", maximum: " << formatReadableSizeWithBinarySuffix(current_limit);
<< ", maximum: " << formatReadableSizeWithBinarySuffix(current_hard_limit);
throw DB::Exception(message.str(), DB::ErrorCodes::MEMORY_LIMIT_EXCEEDED);
}
@ -174,7 +187,8 @@ void MemoryTracker::resetCounters()
{
amount.store(0, std::memory_order_relaxed);
peak.store(0, std::memory_order_relaxed);
limit.store(0, std::memory_order_relaxed);
hard_limit.store(0, std::memory_order_relaxed);
profiler_limit.store(0, std::memory_order_relaxed);
}
@ -187,11 +201,20 @@ void MemoryTracker::reset()
}
void MemoryTracker::setOrRaiseLimit(Int64 value)
void MemoryTracker::setOrRaiseHardLimit(Int64 value)
{
/// This is just atomic set to maximum.
Int64 old_value = limit.load(std::memory_order_relaxed);
while (old_value < value && !limit.compare_exchange_weak(old_value, value))
Int64 old_value = hard_limit.load(std::memory_order_relaxed);
while (old_value < value && !hard_limit.compare_exchange_weak(old_value, value))
;
}
void MemoryTracker::setOrRaiseProfilerLimit(Int64 value)
{
/// This is just atomic set to maximum.
Int64 old_value = profiler_limit.load(std::memory_order_relaxed);
while (old_value < value && !profiler_limit.compare_exchange_weak(old_value, value))
;
}
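The empty-bodied loops above are the standard lock-free "store maximum" idiom: on failure, compare_exchange_weak reloads old_value with the current contents, so the loop retries only while the stored value is still smaller. The same pattern in isolation (a sketch, not ClickHouse code):

#include <atomic>
#include <cstdint>

/// Atomically raises `target` to `value` if `value` is greater; otherwise leaves it unchanged.
void setOrRaise(std::atomic<int64_t> & target, int64_t value)
{
    int64_t old_value = target.load(std::memory_order_relaxed);
    while (old_value < value && !target.compare_exchange_weak(old_value, value))
        ; /// compare_exchange_weak refreshed old_value; re-check and retry
}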
@ -207,7 +230,7 @@ namespace CurrentMemoryTracker
if (untracked > untracked_memory_limit)
{
/// Zero untracked before track. If tracker throws out-of-limit we would be able to alloc up to untracked_memory_limit bytes
/// more. It could be usefull for enlarge Exception message in rethrow logic.
/// more. It could be useful to enlarge Exception message in rethrow logic.
Int64 tmp = untracked;
untracked = 0;
memory_tracker->alloc(tmp);
@ -218,10 +241,7 @@ namespace CurrentMemoryTracker
void realloc(Int64 old_size, Int64 new_size)
{
Int64 addition = new_size - old_size;
if (addition > 0)
alloc(addition);
else
free(-addition);
addition > 0 ? alloc(addition) : free(-addition);
}
void free(Int64 size)

View File

@ -15,7 +15,10 @@ class MemoryTracker
{
std::atomic<Int64> amount {0};
std::atomic<Int64> peak {0};
std::atomic<Int64> limit {0};
std::atomic<Int64> hard_limit {0};
std::atomic<Int64> profiler_limit {0};
Int64 profiler_step = 0;
/// To test exception safety of calling code, memory tracker throws an exception on each memory allocation with specified probability.
double fault_probability = 0;
@ -32,7 +35,6 @@ class MemoryTracker
public:
MemoryTracker(VariableContext level_ = VariableContext::Thread) : level(level_) {}
MemoryTracker(Int64 limit_, VariableContext level_ = VariableContext::Thread) : limit(limit_), level(level_) {}
MemoryTracker(MemoryTracker * parent_, VariableContext level_ = VariableContext::Thread) : parent(parent_), level(level_) {}
~MemoryTracker();
@ -66,21 +68,22 @@ public:
return peak.load(std::memory_order_relaxed);
}
void setLimit(Int64 limit_)
{
limit.store(limit_, std::memory_order_relaxed);
}
/** Set limit if it was not set.
* Otherwise, set limit to new value, if new value is greater than previous limit.
*/
void setOrRaiseLimit(Int64 value);
void setOrRaiseHardLimit(Int64 value);
void setOrRaiseProfilerLimit(Int64 value);
void setFaultProbability(double value)
{
fault_probability = value;
}
void setProfilerStep(Int64 value)
{
profiler_step = value;
}
/// next should be changed only once: from nullptr to some value.
/// NOTE: It is not true in MergeListElement
void setParent(MemoryTracker * elem)

View File

@ -3,11 +3,20 @@
#include "OpenSSLHelpers.h"
#include <ext/scope_guard.h>
#include <openssl/err.h>
#include <openssl/sha.h>
namespace DB
{
#pragma GCC diagnostic warning "-Wold-style-cast"
void encodeSHA256(const std::string_view & text, unsigned char * out)
{
SHA256_CTX ctx;
SHA256_Init(&ctx);
SHA256_Update(&ctx, reinterpret_cast<const UInt8 *>(text.data()), text.size());
SHA256_Final(out, &ctx);
}
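A call-site sketch for the helper above (assumes an SSL-enabled build, since OpenSSL's SHA256 primitives back it):

#include <Common/OpenSSLHelpers.h>
#include <string_view>

void sha256Example()
{
    unsigned char digest[32]; /// the function writes exactly 32 bytes here
    DB::encodeSHA256(std::string_view{"password"}, digest);
}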
String getOpenSSLErrors()
{
BIO * mem = BIO_new(BIO_s_mem());

View File

@ -7,6 +7,8 @@
namespace DB
{
/// Encodes `text` and puts the result to `out` which must be at least 32 bytes long.
void encodeSHA256(const std::string_view & text, unsigned char * out);
/// Returns concatenation of error strings for all errors that OpenSSL has recorded, emptying the error queue.
String getOpenSSLErrors();

View File

@ -1,92 +1,38 @@
#include "QueryProfiler.h"
#include <random>
#include <common/phdr_cache.h>
#include <common/config_common.h>
#include <common/StringRef.h>
#include <common/logger_useful.h>
#include <Common/PipeFDs.h>
#include <Common/StackTrace.h>
#include <Common/CurrentThread.h>
#include <Common/Exception.h>
#include <Common/thread_local_rng.h>
#include <IO/WriteHelpers.h>
#include <IO/WriteBufferFromFileDescriptorDiscardOnFailure.h>
#include <Common/Exception.h>
#include <Common/StackTrace.h>
#include <Common/TraceCollector.h>
#include <Common/thread_local_rng.h>
#include <common/StringRef.h>
#include <common/config_common.h>
#include <common/logger_useful.h>
#include <common/phdr_cache.h>
#include <ext/singleton.h>
#include <random>
namespace ProfileEvents
{
extern const Event QueryProfilerSignalOverruns;
}
namespace DB
{
extern LazyPipeFDs trace_pipe;
namespace
{
/// Normally query_id is a UUID (string with a fixed length) but user can provide custom query_id.
/// Thus an upper bound on query_id length should be introduced to avoid buffer overflow in the signal handler.
constexpr size_t QUERY_ID_MAX_LEN = 1024;
#if defined(OS_LINUX)
thread_local size_t write_trace_iteration = 0;
#endif
void writeTraceInfo(TimerType timer_type, int /* sig */, siginfo_t * info, void * context)
void writeTraceInfo(TraceType trace_type, int /* sig */, siginfo_t * info, void * context)
{
int overrun_count = 0;
#if defined(OS_LINUX)
/// Quickly drop if signal handler is called too frequently.
/// Otherwise we may end up infinitely processing signals instead of doing any useful work.
++write_trace_iteration;
if (info && info->si_overrun > 0)
{
/// But pass with some frequency to avoid drop of all traces.
if (write_trace_iteration % info->si_overrun == 0)
{
ProfileEvents::increment(ProfileEvents::QueryProfilerSignalOverruns, info->si_overrun);
}
else
{
ProfileEvents::increment(ProfileEvents::QueryProfilerSignalOverruns, info->si_overrun + 1);
return;
}
}
if (info)
overrun_count = info->si_overrun;
#else
UNUSED(info);
#endif
constexpr size_t buf_size = sizeof(char) + // TraceCollector stop flag
8 * sizeof(char) + // maximum VarUInt length for string size
QUERY_ID_MAX_LEN * sizeof(char) + // maximum query_id length
sizeof(UInt8) + // number of stack frames
sizeof(StackTrace::Frames) + // collected stack trace, maximum capacity
sizeof(TimerType) + // timer type
sizeof(UInt64); // thread_id
char buffer[buf_size];
WriteBufferFromFileDescriptorDiscardOnFailure out(trace_pipe.fds_rw[1], buf_size, buffer);
StringRef query_id = CurrentThread::getQueryId();
query_id.size = std::min(query_id.size, QUERY_ID_MAX_LEN);
UInt64 thread_id = CurrentThread::get().thread_id;
const auto signal_context = *reinterpret_cast<ucontext_t *>(context);
const StackTrace stack_trace(signal_context);
writeChar(false, out);
writeStringBinary(query_id, out);
size_t stack_trace_size = stack_trace.getSize();
size_t stack_trace_offset = stack_trace.getOffset();
writeIntBinary(UInt8(stack_trace_size - stack_trace_offset), out);
for (size_t i = stack_trace_offset; i < stack_trace_size; ++i)
writePODBinary(stack_trace.getFrames()[i], out);
writePODBinary(timer_type, out);
writePODBinary(thread_id, out);
out.next();
ext::Singleton<TraceCollector>()->collect(trace_type, stack_trace, overrun_count);
}
[[maybe_unused]] const UInt32 TIMER_PRECISION = 1e9;
@ -135,11 +81,11 @@ QueryProfilerBase<ProfilerImpl>::QueryProfilerBase(const UInt64 thread_id, const
sev.sigev_notify = SIGEV_THREAD_ID;
sev.sigev_signo = pause_signal;
#if defined(__FreeBSD__)
# if defined(__FreeBSD__)
sev._sigev_un._threadid = thread_id;
#else
# else
sev._sigev_un._tid = thread_id;
#endif
# endif
if (timer_create(clock_type, &sev, &timer_id))
{
/// In Google Cloud Run, the function "timer_create" is implemented incorrectly as of 2020-01-25.
@ -206,7 +152,7 @@ QueryProfilerReal::QueryProfilerReal(const UInt64 thread_id, const UInt32 period
void QueryProfilerReal::signalHandler(int sig, siginfo_t * info, void * context)
{
writeTraceInfo(TimerType::Real, sig, info, context);
writeTraceInfo(TraceType::REAL_TIME, sig, info, context);
}
QueryProfilerCpu::QueryProfilerCpu(const UInt64 thread_id, const UInt32 period)
@ -215,7 +161,7 @@ QueryProfilerCpu::QueryProfilerCpu(const UInt64 thread_id, const UInt32 period)
void QueryProfilerCpu::signalHandler(int sig, siginfo_t * info, void * context)
{
writeTraceInfo(TimerType::Cpu, sig, info, context);
writeTraceInfo(TraceType::CPU_TIME, sig, info, context);
}
}

View File

@ -15,12 +15,6 @@ namespace Poco
namespace DB
{
enum class TimerType : UInt8
{
Real,
Cpu,
};
/**
* Query profiler implementation for selected thread.
*

View File

@ -1,25 +1,38 @@
#include "TraceCollector.h"
#include <Core/Field.h>
#include <IO/ReadBufferFromFileDescriptor.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromFileDescriptor.h>
#include <IO/WriteBufferFromFileDescriptorDiscardOnFailure.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/TraceLog.h>
#include <Poco/Logger.h>
#include <Common/Exception.h>
#include <Common/PipeFDs.h>
#include <Common/StackTrace.h>
#include <common/logger_useful.h>
#include <IO/ReadHelpers.h>
#include <IO/ReadBufferFromFileDescriptor.h>
#include <IO/WriteHelpers.h>
#include <IO/WriteBufferFromFileDescriptor.h>
#include <Common/Exception.h>
#include <Interpreters/TraceLog.h>
#include <unistd.h>
#include <fcntl.h>
namespace ProfileEvents
{
extern const Event QueryProfilerSignalOverruns;
}
namespace DB
{
LazyPipeFDs trace_pipe;
namespace
{
/// Normally query_id is a UUID (string with a fixed length) but user can provide custom query_id.
/// Thus an upper bound on query_id length should be introduced to avoid buffer overflow in the signal handler.
constexpr size_t QUERY_ID_MAX_LEN = 1024;
thread_local size_t write_trace_iteration = 0;
}
namespace ErrorCodes
{
@ -27,20 +40,15 @@ namespace ErrorCodes
extern const int THREAD_IS_NOT_JOINABLE;
}
TraceCollector::TraceCollector(std::shared_ptr<TraceLog> & trace_log_)
: log(&Poco::Logger::get("TraceCollector"))
, trace_log(trace_log_)
TraceCollector::TraceCollector()
{
if (trace_log == nullptr)
throw Exception("Invalid trace log pointer passed", ErrorCodes::NULL_POINTER_DEREFERENCE);
trace_pipe.open();
pipe.open();
/** Turn write end of pipe to non-blocking mode to avoid deadlocks
* when QueryProfiler is invoked under locks and TraceCollector cannot pull data from pipe.
*/
trace_pipe.setNonBlocking();
trace_pipe.tryIncreaseSize(1 << 20);
pipe.setNonBlocking();
pipe.tryIncreaseSize(1 << 20);
thread = ThreadFromGlobalPool(&TraceCollector::run, this);
}
@ -48,14 +56,101 @@ TraceCollector::TraceCollector(std::shared_ptr<TraceLog> & trace_log_)
TraceCollector::~TraceCollector()
{
if (!thread.joinable())
LOG_ERROR(log, "TraceCollector thread is malformed and cannot be joined");
LOG_ERROR(&Poco::Logger::get("TraceCollector"), "TraceCollector thread is malformed and cannot be joined");
else
{
TraceCollector::notifyToStop();
stop();
thread.join();
}
trace_pipe.close();
pipe.close();
}
void TraceCollector::collect(TraceType trace_type, const StackTrace & stack_trace, int overrun_count)
{
/// Quickly drop if signal handler is called too frequently.
/// Otherwise we may end up infinitely processing signals instead of doing any useful work.
++write_trace_iteration;
if (overrun_count)
{
/// But pass with some frequency to avoid drop of all traces.
if (write_trace_iteration % overrun_count == 0)
{
ProfileEvents::increment(ProfileEvents::QueryProfilerSignalOverruns, overrun_count);
}
else
{
ProfileEvents::increment(ProfileEvents::QueryProfilerSignalOverruns, overrun_count + 1);
return;
}
}
constexpr size_t buf_size = sizeof(char) + // TraceCollector stop flag
8 * sizeof(char) + // maximum VarUInt length for string size
QUERY_ID_MAX_LEN * sizeof(char) + // maximum query_id length
sizeof(UInt8) + // number of stack frames
sizeof(StackTrace::Frames) + // collected stack trace, maximum capacity
sizeof(TraceType) + // trace type
sizeof(UInt64) + // thread_id
sizeof(UInt64); // size
char buffer[buf_size];
WriteBufferFromFileDescriptorDiscardOnFailure out(pipe.fds_rw[1], buf_size, buffer);
StringRef query_id = CurrentThread::getQueryId();
query_id.size = std::min(query_id.size, QUERY_ID_MAX_LEN);
auto thread_id = CurrentThread::get().thread_id;
writeChar(false, out);
writeStringBinary(query_id, out);
size_t stack_trace_size = stack_trace.getSize();
size_t stack_trace_offset = stack_trace.getOffset();
writeIntBinary(UInt8(stack_trace_size - stack_trace_offset), out);
for (size_t i = stack_trace_offset; i < stack_trace_size; ++i)
writePODBinary(stack_trace.getFrames()[i], out);
writePODBinary(trace_type, out);
writePODBinary(thread_id, out);
writePODBinary(UInt64(0), out);
out.next();
}
void TraceCollector::collect(UInt64 size)
{
constexpr size_t buf_size = sizeof(char) + // TraceCollector stop flag
8 * sizeof(char) + // maximum VarUInt length for string size
QUERY_ID_MAX_LEN * sizeof(char) + // maximum query_id length
sizeof(UInt8) + // number of stack frames
sizeof(StackTrace::Frames) + // collected stack trace, maximum capacity
sizeof(TraceType) + // trace type
sizeof(UInt64) + // thread_id
sizeof(UInt64); // size
char buffer[buf_size];
WriteBufferFromFileDescriptorDiscardOnFailure out(pipe.fds_rw[1], buf_size, buffer);
StringRef query_id = CurrentThread::getQueryId();
query_id.size = std::min(query_id.size, QUERY_ID_MAX_LEN);
auto thread_id = CurrentThread::get().thread_id;
writeChar(false, out);
writeStringBinary(query_id, out);
const auto & stack_trace = StackTrace();
size_t stack_trace_size = stack_trace.getSize();
size_t stack_trace_offset = stack_trace.getOffset();
writeIntBinary(UInt8(stack_trace_size - stack_trace_offset), out);
for (size_t i = stack_trace_offset; i < stack_trace_size; ++i)
writePODBinary(stack_trace.getFrames()[i], out);
writePODBinary(TraceType::MEMORY, out);
writePODBinary(thread_id, out);
writePODBinary(size, out);
out.next();
}
/**
@ -68,16 +163,16 @@ TraceCollector::~TraceCollector()
* NOTE: TraceCollector will NOT stop immediately as there may be some data left in the pipe
* before stop message.
*/
void TraceCollector::notifyToStop()
void TraceCollector::stop()
{
WriteBufferFromFileDescriptor out(trace_pipe.fds_rw[1]);
WriteBufferFromFileDescriptor out(pipe.fds_rw[1]);
writeChar(true, out);
out.next();
}
void TraceCollector::run()
{
ReadBufferFromFileDescriptor in(trace_pipe.fds_rw[0]);
ReadBufferFromFileDescriptor in(pipe.fds_rw[0]);
while (true)
{
@ -89,27 +184,33 @@ void TraceCollector::run()
std::string query_id;
readStringBinary(query_id, in);
UInt8 size = 0;
readIntBinary(size, in);
UInt8 trace_size = 0;
readIntBinary(trace_size, in);
Array trace;
trace.reserve(size);
trace.reserve(trace_size);
for (size_t i = 0; i < size; i++)
for (size_t i = 0; i < trace_size; i++)
{
uintptr_t addr = 0;
readPODBinary(addr, in);
trace.emplace_back(UInt64(addr));
}
TimerType timer_type;
readPODBinary(timer_type, in);
TraceType trace_type;
readPODBinary(trace_type, in);
UInt64 thread_id;
readPODBinary(thread_id, in);
TraceLogElement element{std::time(nullptr), timer_type, thread_id, query_id, trace};
trace_log->add(element);
UInt64 size;
readPODBinary(size, in);
if (trace_log)
{
TraceLogElement element{std::time(nullptr), trace_type, thread_id, query_id, trace, size};
trace_log->add(element);
}
}
}

View File

@ -1,7 +1,10 @@
#pragma once
#include "Common/PipeFDs.h"
#include <Common/ThreadPool.h>
class StackTrace;
namespace Poco
{
class Logger;
@ -12,21 +15,31 @@ namespace DB
class TraceLog;
enum class TraceType : UInt8
{
REAL_TIME,
CPU_TIME,
MEMORY,
};
class TraceCollector
{
public:
TraceCollector();
~TraceCollector();
void setTraceLog(const std::shared_ptr<TraceLog> & trace_log_) { trace_log = trace_log_; }
void collect(TraceType type, const StackTrace & stack_trace, int overrun_count = 0);
void collect(UInt64 size);
private:
Poco::Logger * log;
std::shared_ptr<TraceLog> trace_log;
ThreadFromGlobalPool thread;
LazyPipeFDs pipe;
void run();
static void notifyToStop();
public:
TraceCollector(std::shared_ptr<TraceLog> & trace_log_);
~TraceCollector();
void stop();
};
}

View File

@ -1,14 +1,16 @@
#if defined(OS_LINUX)
#include <malloc.h>
#elif defined(OS_DARWIN)
#include <malloc/malloc.h>
#endif
#include <new>
#include <common/config_common.h>
#include <common/memory.h>
#include <Common/MemoryTracker.h>
#include <iostream>
#include <new>
#if defined(OS_LINUX)
# include <malloc.h>
#elif defined(OS_DARWIN)
# include <malloc/malloc.h>
#endif
/// Replace default new/delete with memory tracking versions.
/// @sa https://en.cppreference.com/w/cpp/memory/new/operator_new
/// https://en.cppreference.com/w/cpp/memory/new/operator_delete
@ -29,7 +31,7 @@ ALWAYS_INLINE void trackMemory(std::size_t size)
#endif
}
ALWAYS_INLINE bool trackMemoryNoExept(std::size_t size) noexcept
ALWAYS_INLINE bool trackMemoryNoExcept(std::size_t size) noexcept
{
try
{
@ -54,11 +56,11 @@ ALWAYS_INLINE void untrackMemory(void * ptr [[maybe_unused]], std::size_t size [
#else
if (size)
CurrentMemoryTracker::free(size);
#ifdef _GNU_SOURCE
# ifdef _GNU_SOURCE
/// It's an inaccurate resource free for sanitizers: malloc_usable_size() result is greater or equal to the allocated size.
else
CurrentMemoryTracker::free(malloc_usable_size(ptr));
#endif
# endif
#endif
}
catch (...)
@ -83,14 +85,14 @@ void * operator new[](std::size_t size)
void * operator new(std::size_t size, const std::nothrow_t &) noexcept
{
if (likely(Memory::trackMemoryNoExept(size)))
if (likely(Memory::trackMemoryNoExcept(size)))
return Memory::newNoExept(size);
return nullptr;
}
void * operator new[](std::size_t size, const std::nothrow_t &) noexcept
{
if (likely(Memory::trackMemoryNoExept(size)))
if (likely(Memory::trackMemoryNoExcept(size)))
return Memory::newNoExept(size);
return nullptr;
}

View File

@ -0,0 +1,17 @@
#pragma once
#include <Interpreters/Context.h>
inline DB::Context createContext()
{
auto context = DB::Context::createGlobal();
context.makeGlobalContext();
context.setPath("./");
return context;
}
inline const DB::Context & getContext()
{
static DB::Context global_context = createContext();
return global_context;
}

View File

@ -241,30 +241,35 @@ void decompressDataForType(const char * source, UInt32 source_size, char * dest)
const char * source_end = source + source_size;
if (source + sizeof(UInt32) > source_end)
return;
const UInt32 items_count = unalignedLoad<UInt32>(source);
source += sizeof(items_count);
ValueType prev_value{};
UnsignedDeltaType prev_delta{};
if (source < source_end)
{
prev_value = unalignedLoad<ValueType>(source);
unalignedStore<ValueType>(dest, prev_value);
// decoding first item
if (source + sizeof(ValueType) > source_end || items_count < 1)
return;
source += sizeof(prev_value);
dest += sizeof(prev_value);
}
prev_value = unalignedLoad<ValueType>(source);
unalignedStore<ValueType>(dest, prev_value);
if (source < source_end)
{
prev_delta = unalignedLoad<UnsignedDeltaType>(source);
prev_value = prev_value + static_cast<ValueType>(prev_delta);
unalignedStore<ValueType>(dest, prev_value);
source += sizeof(prev_value);
dest += sizeof(prev_value);
source += sizeof(prev_delta);
dest += sizeof(prev_value);
}
// decoding second item
if (source + sizeof(UnsignedDeltaType) > source_end || items_count < 2)
return;
prev_delta = unalignedLoad<UnsignedDeltaType>(source);
prev_value = prev_value + static_cast<ValueType>(prev_delta);
unalignedStore<ValueType>(dest, prev_value);
source += sizeof(prev_delta);
dest += sizeof(prev_value);
BitReader reader(source, source_size - sizeof(prev_value) - sizeof(prev_delta) - sizeof(items_count));
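For orientation, this is the header layout the fixed decoder walks, shown for ValueType == UnsignedDeltaType == UInt64 (a reading of the code above, not a format specification):

/// offset  0: UInt32 items_count          -- bail out if the buffer is shorter
/// offset  4: UInt64 first value          -- copied to dest verbatim; requires items_count >= 1
/// offset 12: UInt64 first unsigned delta -- added to the first value; requires items_count >= 2
/// offset 20: bit-packed double deltas, consumed through BitReader until source_end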

View File

@ -159,19 +159,23 @@ void decompressDataForType(const char * source, UInt32 source_size, char * dest)
const char * source_end = source + source_size;
if (source + sizeof(UInt32) > source_end)
return;
const UInt32 items_count = unalignedLoad<UInt32>(source);
source += sizeof(items_count);
T prev_value{};
if (source < source_end)
{
prev_value = unalignedLoad<T>(source);
unalignedStore<T>(dest, prev_value);
// decoding first item
if (source + sizeof(T) > source_end || items_count < 1)
return;
source += sizeof(prev_value);
dest += sizeof(prev_value);
}
prev_value = unalignedLoad<T>(source);
unalignedStore<T>(dest, prev_value);
source += sizeof(prev_value);
dest += sizeof(prev_value);
BitReader reader(source, source_size - sizeof(items_count) - sizeof(prev_value));

View File

@ -23,6 +23,78 @@ extern const int LOGICAL_ERROR;
namespace
{
/// Fixed TypeIds that numbers would not be changed between versions.
enum class MagicNumber : uint8_t
{
UInt8 = 1,
UInt16 = 2,
UInt32 = 3,
UInt64 = 4,
Int8 = 6,
Int16 = 7,
Int32 = 8,
Int64 = 9,
Date = 13,
DateTime = 14,
DateTime64 = 15,
Enum8 = 17,
Enum16 = 18,
Decimal32 = 19,
Decimal64 = 20,
};
MagicNumber serializeTypeId(TypeIndex type_id)
{
switch (type_id)
{
case TypeIndex::UInt8: return MagicNumber::UInt8;
case TypeIndex::UInt16: return MagicNumber::UInt16;
case TypeIndex::UInt32: return MagicNumber::UInt32;
case TypeIndex::UInt64: return MagicNumber::UInt64;
case TypeIndex::Int8: return MagicNumber::Int8;
case TypeIndex::Int16: return MagicNumber::Int16;
case TypeIndex::Int32: return MagicNumber::Int32;
case TypeIndex::Int64: return MagicNumber::Int64;
case TypeIndex::Date: return MagicNumber::Date;
case TypeIndex::DateTime: return MagicNumber::DateTime;
case TypeIndex::DateTime64: return MagicNumber::DateTime64;
case TypeIndex::Enum8: return MagicNumber::Enum8;
case TypeIndex::Enum16: return MagicNumber::Enum16;
case TypeIndex::Decimal32: return MagicNumber::Decimal32;
case TypeIndex::Decimal64: return MagicNumber::Decimal64;
default:
break;
}
throw Exception("Type is not supported by T64 codec: " + toString(UInt32(type_id)), ErrorCodes::LOGICAL_ERROR);
}
TypeIndex deserializeTypeId(uint8_t serialized_type_id)
{
MagicNumber magic = static_cast<MagicNumber>(serialized_type_id);
switch (magic)
{
case MagicNumber::UInt8: return TypeIndex::UInt8;
case MagicNumber::UInt16: return TypeIndex::UInt16;
case MagicNumber::UInt32: return TypeIndex::UInt32;
case MagicNumber::UInt64: return TypeIndex::UInt64;
case MagicNumber::Int8: return TypeIndex::Int8;
case MagicNumber::Int16: return TypeIndex::Int16;
case MagicNumber::Int32: return TypeIndex::Int32;
case MagicNumber::Int64: return TypeIndex::Int64;
case MagicNumber::Date: return TypeIndex::Date;
case MagicNumber::DateTime: return TypeIndex::DateTime;
case MagicNumber::DateTime64: return TypeIndex::DateTime64;
case MagicNumber::Enum8: return TypeIndex::Enum8;
case MagicNumber::Enum16: return TypeIndex::Enum16;
case MagicNumber::Decimal32: return TypeIndex::Decimal32;
case MagicNumber::Decimal64: return TypeIndex::Decimal64;
}
throw Exception("Bad magic number in T64 codec: " + toString(UInt32(serialized_type_id)), ErrorCodes::LOGICAL_ERROR);
}
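A round-trip sketch over the two helpers above (callable only from this translation unit, since they live in an anonymous namespace; the wrapper function is illustrative):

#include <cassert>

void checkTypeIdRoundTrip()
{
    TypeIndex idx = TypeIndex::DateTime64;
    uint8_t on_disk = static_cast<uint8_t>(serializeTypeId(idx)); /// 15, fixed across versions
    assert(deserializeTypeId(on_disk) == idx); /// throws on an unknown magic number instead of mis-mapping
}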
UInt8 codecId()
{
return static_cast<UInt8>(CompressionMethodByte::T64);
@ -41,6 +113,7 @@ TypeIndex baseType(TypeIndex type_idx)
return TypeIndex::Int32;
case TypeIndex::Int64:
case TypeIndex::Decimal64:
case TypeIndex::DateTime64:
return TypeIndex::Int64;
case TypeIndex::UInt8:
case TypeIndex::Enum8:
@ -79,6 +152,7 @@ TypeIndex typeIdx(const DataTypePtr & data_type)
case TypeIndex::Int32:
case TypeIndex::UInt32:
case TypeIndex::DateTime:
case TypeIndex::DateTime64:
case TypeIndex::Decimal32:
case TypeIndex::Int64:
case TypeIndex::UInt64:
@ -490,7 +564,7 @@ void decompressData(const char * src, UInt32 src_size, char * dst, UInt32 uncomp
UInt32 CompressionCodecT64::doCompressData(const char * src, UInt32 src_size, char * dst) const
{
UInt8 cookie = static_cast<UInt8>(type_idx) | (static_cast<UInt8>(variant) << 7);
UInt8 cookie = static_cast<UInt8>(serializeTypeId(type_idx)) | (static_cast<UInt8>(variant) << 7);
memcpy(dst, &cookie, 1);
dst += 1;
@ -529,7 +603,7 @@ void CompressionCodecT64::doDecompressData(const char * src, UInt32 src_size, ch
src_size -= 1;
auto saved_variant = static_cast<Variant>(cookie >> 7);
auto saved_type_id = static_cast<TypeIndex>(cookie & 0x7F);
TypeIndex saved_type_id = deserializeTypeId(cookie & 0x7F);
switch (baseType(saved_type_id))
{

View File

@ -158,8 +158,8 @@ public:
explicit BinaryDataAsSequenceOfValuesIterator(const Container & container_)
: container(container_),
data(&container[0]),
data_end(reinterpret_cast<const char *>(data) + container.size()),
data(container.data()),
data_end(container.data() + container.size()),
current_value(T{})
{
static_assert(sizeof(container[0]) == 1 && std::is_pod<std::decay_t<decltype(container[0])>>::value, "Only works on containers of byte-size PODs.");
@ -789,12 +789,14 @@ auto FFand0Generator = []()
};
// Makes many sequences with generator, first sequence length is 1, second is 2... up to `sequences_count`.
// Makes many sequences with generator; first sequence length is 0, second is 1, third is 2, and so on up to `sequences_count`.
template <typename T, typename Generator>
std::vector<CodecTestSequence> generatePyramidOfSequences(const size_t sequences_count, Generator && generator, const char* generator_name)
{
std::vector<CodecTestSequence> sequences;
sequences.reserve(sequences_count);
sequences.push_back(makeSeq<T>()); // sequence of size 0
for (size_t i = 1; i < sequences_count; ++i)
{
std::string name = generator_name + std::string(" from 0 to ") + std::to_string(i);

View File

@ -6,6 +6,7 @@
#define DBMS_DEFAULT_HTTP_PORT 8123
#define DBMS_DEFAULT_CONNECT_TIMEOUT_SEC 10
#define DBMS_DEFAULT_CONNECT_TIMEOUT_WITH_FAILOVER_MS 50
#define DBMS_DEFAULT_CONNECT_TIMEOUT_WITH_FAILOVER_SECURE_MS 100
#define DBMS_DEFAULT_SEND_TIMEOUT_SEC 300
#define DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC 300
/// Timeout for synchronous request-result protocol call (like Ping or TablesStatus).

View File

@ -7,6 +7,7 @@
#include <Common/PODArray.h>
#include <Core/Types.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/User.h>
#include <IO/copyData.h>
#include <IO/LimitReadBuffer.h>
@ -952,7 +953,7 @@ public:
throw Exception("Wrong size of auth response. Expected: " + std::to_string(Poco::SHA1Engine::DIGEST_SIZE) + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.",
ErrorCodes::UNKNOWN_EXCEPTION);
auto user = context.getUser(user_name);
auto user = context.getAccessControlManager().getUser(user_name);
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1();
assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);

View File

@ -62,6 +62,7 @@ struct Settings : public SettingsCollection<Settings>
M(SettingUInt64, interactive_delay, 100000, "The interval in microseconds to check if the request is cancelled, and to send progress info.", 0) \
M(SettingSeconds, connect_timeout, DBMS_DEFAULT_CONNECT_TIMEOUT_SEC, "Connection timeout if there are no replicas.", 0) \
M(SettingMilliseconds, connect_timeout_with_failover_ms, DBMS_DEFAULT_CONNECT_TIMEOUT_WITH_FAILOVER_MS, "Connection timeout for selecting first healthy replica.", 0) \
M(SettingMilliseconds, connect_timeout_with_failover_secure_ms, DBMS_DEFAULT_CONNECT_TIMEOUT_WITH_FAILOVER_SECURE_MS, "Connection timeout for selecting first healthy replica (for secure connections).", 0) \
M(SettingSeconds, receive_timeout, DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC, "", 0) \
M(SettingSeconds, send_timeout, DBMS_DEFAULT_SEND_TIMEOUT_SEC, "", 0) \
M(SettingSeconds, tcp_keep_alive_timeout, 0, "The time in seconds the connection needs to remain idle before TCP starts sending keepalive probes", 0) \
@ -203,6 +204,7 @@ struct Settings : public SettingsCollection<Settings>
M(SettingUInt64, output_format_parquet_row_group_size, 1000000, "Row group size in rows.", 0) \
M(SettingString, output_format_avro_codec, "", "Compression codec used for output. Possible values: 'null', 'deflate', 'snappy'.", 0) \
M(SettingUInt64, output_format_avro_sync_interval, 16 * 1024, "Sync interval in bytes.", 0) \
    M(SettingBool, output_format_tsv_crlf_end_of_line, false, "If it is set to true, the end of line in TSV format will be \\r\\n instead of \\n.", 0) \
\
M(SettingBool, use_client_time_zone, false, "Use client timezone for interpreting DateTime string values, instead of adopting server timezone.", 0) \
\
@ -330,6 +332,7 @@ struct Settings : public SettingsCollection<Settings>
M(SettingUInt64, max_memory_usage, 0, "Maximum memory usage for processing of single query. Zero means unlimited.", 0) \
M(SettingUInt64, max_memory_usage_for_user, 0, "Maximum memory usage for processing all concurrently running queries for the user. Zero means unlimited.", 0) \
M(SettingUInt64, max_memory_usage_for_all_queries, 0, "Maximum memory usage for processing all concurrently running queries on the server. Zero means unlimited.", 0) \
M(SettingUInt64, memory_profiler_step, 0, "Every number of bytes the memory profiler will dump the allocating stacktrace. Zero means disabled memory profiler.", 0) \
\
M(SettingUInt64, max_network_bandwidth, 0, "The maximum speed of data exchange over the network in bytes per second for a query. Zero means unlimited.", 0) \
M(SettingUInt64, max_network_bytes, 0, "The maximum number of bytes (compressed) to receive or transmit over the network for execution of the query.", 0) \
@ -338,7 +341,7 @@ struct Settings : public SettingsCollection<Settings>
M(SettingChar, format_csv_delimiter, ',', "The character to be considered as a delimiter in CSV data. If setting with a string, a string has to have a length of 1.", 0) \
M(SettingBool, format_csv_allow_single_quotes, 1, "If it is set to true, allow strings in single quotes.", 0) \
M(SettingBool, format_csv_allow_double_quotes, 1, "If it is set to true, allow strings in double quotes.", 0) \
M(SettingBool, output_format_csv_crlf_end_of_line, false, "If it is set true, end of line will be \\r\\n instead of \\n.", 0) \
    M(SettingBool, output_format_csv_crlf_end_of_line, false, "If it is set to true, the end of line in CSV format will be \\r\\n instead of \\n.", 0) \
M(SettingBool, input_format_csv_unquoted_null_literal_as_null, false, "Consider unquoted NULL literal as \\N", 0) \
\
M(SettingDateTimeInputFormat, date_time_input_format, FormatSettings::DateTimeInputFormat::Basic, "Method to read DateTime from text input formats. Possible values: 'basic' and 'best_effort'.", 0) \
@ -390,6 +393,9 @@ struct Settings : public SettingsCollection<Settings>
M(SettingUInt64, mutations_sync, 0, "Wait for synchronous execution of ALTER TABLE UPDATE/DELETE queries (mutations). 0 - execute asynchronously. 1 - wait current server. 2 - wait all replicas if they exist.", 0) \
M(SettingBool, optimize_if_chain_to_miltiif, false, "Replace if(cond1, then1, if(cond2, ...)) chains with multiIf. Currently it is not beneficial for numeric types.", 0) \
M(SettingBool, allow_experimental_alter_materialized_view_structure, false, "Allow atomic alter on Materialized views. Work in progress.", 0) \
M(SettingBool, enable_early_constant_folding, true, "Enable the query optimization where we analyze function and subquery results and rewrite the query if constants are found there", 0) \
\
M(SettingBool, partial_revokes, false, "Makes it possible to revoke privileges partially.", 0) \
\
/** Obsolete settings that do nothing but left for compatibility reasons. Remove each one after half a year of obsolescence. */ \
\

View File

@ -421,14 +421,7 @@ void SettingURI::set(const Field & x)
void SettingURI::set(const String & x)
{
try {
Poco::URI uri(x);
set(uri);
}
catch (const Poco::Exception& e)
{
throw Exception{Exception::CreateFromPoco, e};
}
set(Poco::URI(x));
}
void SettingURI::serialize(WriteBuffer & buf, SettingsBinaryFormat) const

View File

@ -14,6 +14,7 @@ namespace DB
struct Null {};
/// @note Except where explicitly described, you should not rely on TypeIndex numbers and/or their order in this enum.
enum class TypeIndex
{
Nothing = 0,

View File

@ -1,6 +1,7 @@
#include <DataStreams/ExpressionBlockInputStream.h>
#include <DataStreams/CheckConstraintsBlockOutputStream.h>
#include <Parsers/formatAST.h>
#include <Interpreters/ExpressionActions.h>
#include <Columns/ColumnsCommon.h>
#include <Columns/ColumnsNumber.h>
#include <Common/assert_cast.h>

View File

@ -34,7 +34,7 @@ void ParallelParsingBlockInputStream::segmentatorThreadFunction()
unit.is_last = !have_more_data;
unit.status = READY_TO_PARSE;
scheduleParserThreadForUnitWithNumber(current_unit_number);
scheduleParserThreadForUnitWithNumber(segmentator_ticket_number);
++segmentator_ticket_number;
if (!have_more_data)
@ -49,12 +49,13 @@ void ParallelParsingBlockInputStream::segmentatorThreadFunction()
}
}
void ParallelParsingBlockInputStream::parserThreadFunction(size_t current_unit_number)
void ParallelParsingBlockInputStream::parserThreadFunction(size_t current_ticket_number)
{
try
{
setThreadName("ChunkParser");
const auto current_unit_number = current_ticket_number % processing_units.size();
auto & unit = processing_units[current_unit_number];
/*
@ -64,9 +65,9 @@ void ParallelParsingBlockInputStream::parserThreadFunction(size_t current_unit_n
* can use it from multiple threads simultaneously.
*/
ReadBuffer read_buffer(unit.segment.data(), unit.segment.size(), 0);
auto parser = std::make_unique<InputStreamFromInputFormat>(
input_processor_creator(read_buffer, header,
row_input_format_params, format_settings));
auto format = input_processor_creator(read_buffer, header, row_input_format_params, format_settings);
format->setCurrentUnitNumber(current_ticket_number);
auto parser = std::make_unique<InputStreamFromInputFormat>(std::move(format));
unit.block_ext.block.clear();
unit.block_ext.block_missing_values.clear();
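The ticket scheme above decouples the monotonically growing ticket counter from the fixed ring of processing units: the segmentator hands out tickets, and each ticket maps onto a unit by modulo. A minimal standalone sketch of that mapping (the unit count is an assumed value, standing in for processing_units.size()):

#include <cstddef>
#include <iostream>

int main()
{
    const size_t processing_units_count = 4;       /// stands in for processing_units.size()
    for (size_t ticket = 0; ticket < 10; ++ticket) /// stands in for segmentator_ticket_number
        std::cout << "ticket " << ticket << " -> unit " << ticket % processing_units_count << "\n";
}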

View File

@ -213,9 +213,9 @@ private:
std::deque<ProcessingUnit> processing_units;
void scheduleParserThreadForUnitWithNumber(size_t unit_number)
void scheduleParserThreadForUnitWithNumber(size_t ticket_number)
{
pool.scheduleOrThrowOnError(std::bind(&ParallelParsingBlockInputStream::parserThreadFunction, this, unit_number));
pool.scheduleOrThrowOnError(std::bind(&ParallelParsingBlockInputStream::parserThreadFunction, this, ticket_number));
}
void finishAndWait()

View File

@ -8,7 +8,6 @@
#include <Parsers/ASTInsertQuery.h>
#include <Common/CurrentThread.h>
#include <Common/setThreadName.h>
#include <Common/getNumberOfPhysicalCPUCores.h>
#include <Common/ThreadPool.h>
#include <Storages/MergeTree/ReplicatedMergeTreeBlockOutputStream.h>
#include <Storages/StorageValues.h>
@ -51,8 +50,10 @@ PushingToViewsBlockOutputStream::PushingToViewsBlockOutputStream(
ASTPtr query;
BlockOutputStreamPtr out;
if (auto * materialized_view = dynamic_cast<const StorageMaterializedView *>(dependent_table.get()))
if (auto * materialized_view = dynamic_cast<StorageMaterializedView *>(dependent_table.get()))
{
addTableLock(materialized_view->lockStructureForShare(true, context.getInitialQueryId()));
StoragePtr inner_table = materialized_view->getTargetTable();
auto inner_table_id = inner_table->getStorageID();
query = materialized_view->getInnerQuery();

View File

@ -621,6 +621,12 @@ inline bool isStringOrFixedString(const T & data_type)
return WhichDataType(data_type).isStringOrFixedString();
}
template <typename T>
inline bool isNotCreatable(const T & data_type)
{
WhichDataType which(data_type);
return which.isNothing() || which.isFunction() || which.isSet();
}
inline bool isNotDecimalButComparableToDecimal(const DataTypePtr & data_type)
{

View File

@ -252,20 +252,23 @@ void DatabaseOrdinary::alterTable(
ast->replace(ast_create_query.select, metadata.select);
}
ASTStorage & storage_ast = *ast_create_query.storage;
/// ORDER BY may change, but cannot appear, it's required construction
if (metadata.order_by_ast && storage_ast.order_by)
storage_ast.set(storage_ast.order_by, metadata.order_by_ast);
/// MaterializedView is one type of CREATE query without storage.
if (ast_create_query.storage)
{
ASTStorage & storage_ast = *ast_create_query.storage;
/// ORDER BY may change, but it cannot appear anew, because it is a required clause.
if (metadata.order_by_ast && storage_ast.order_by)
storage_ast.set(storage_ast.order_by, metadata.order_by_ast);
if (metadata.primary_key_ast)
storage_ast.set(storage_ast.primary_key, metadata.primary_key_ast);
if (metadata.primary_key_ast)
storage_ast.set(storage_ast.primary_key, metadata.primary_key_ast);
if (metadata.ttl_for_table_ast)
storage_ast.set(storage_ast.ttl_table, metadata.ttl_for_table_ast);
if (metadata.settings_ast)
storage_ast.set(storage_ast.settings, metadata.settings_ast);
if (metadata.ttl_for_table_ast)
storage_ast.set(storage_ast.ttl_table, metadata.ttl_for_table_ast);
if (metadata.settings_ast)
storage_ast.set(storage_ast.settings, metadata.settings_ast);
}
statement = getObjectDefinitionFromCreateQuery(ast);
{

View File

@ -12,6 +12,7 @@
#include <Common/typeid_cast.h>
#include <ext/range.h>
#include <ext/size.h>
#include <Common/setThreadName.h>
#include "CacheDictionary.inc.h"
#include "DictionaryBlockInputStream.h"
#include "DictionaryFactory.h"
@ -61,24 +62,48 @@ CacheDictionary::CacheDictionary(
const std::string & name_,
const DictionaryStructure & dict_struct_,
DictionarySourcePtr source_ptr_,
const DictionaryLifetime dict_lifetime_,
const size_t size_)
DictionaryLifetime dict_lifetime_,
size_t size_,
bool allow_read_expired_keys_,
size_t max_update_queue_size_,
size_t update_queue_push_timeout_milliseconds_,
size_t max_threads_for_updates_)
: database(database_)
, name(name_)
, full_name{database_.empty() ? name_ : (database_ + "." + name_)}
, dict_struct(dict_struct_)
, source_ptr{std::move(source_ptr_)}
, dict_lifetime(dict_lifetime_)
, allow_read_expired_keys(allow_read_expired_keys_)
, max_update_queue_size(max_update_queue_size_)
, update_queue_push_timeout_milliseconds(update_queue_push_timeout_milliseconds_)
, max_threads_for_updates(max_threads_for_updates_)
, log(&Logger::get("ExternalDictionaries"))
, size{roundUpToPowerOfTwoOrZero(std::max(size_, size_t(max_collision_length)))}
, size_overlap_mask{this->size - 1}
, cells{this->size}
, rnd_engine(randomSeed())
, update_queue(max_update_queue_size_)
, update_pool(max_threads_for_updates)
{
if (!this->source_ptr->supportsSelectiveLoad())
throw Exception{full_name + ": source cannot be used with CacheDictionary", ErrorCodes::UNSUPPORTED_METHOD};
createAttributes();
for (size_t i = 0; i < max_threads_for_updates; ++i)
update_pool.scheduleOrThrowOnError([this] { updateThreadFunction(); });
}
CacheDictionary::~CacheDictionary()
{
finished = true;
update_queue.clear();
for (size_t i = 0; i < max_threads_for_updates; ++i)
{
auto empty_finishing_ptr = std::make_shared<UpdateUnit>(std::vector<Key>());
update_queue.push(empty_finishing_ptr);
}
update_pool.wait();
}
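The destructor above is a poison-pill shutdown: set the `finished` flag, then push one empty unit per worker so every thread blocked in `pop()` wakes up and re-checks the flag. A self-contained sketch of the same pattern with a toy blocking queue (the queue, the worker count and the names are assumptions, not ClickHouse's ConcurrentBoundedQueue):

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

template <typename T>
class ToyBlockingQueue
{
    std::queue<T> items;
    std::mutex mutex;
    std::condition_variable not_empty;
public:
    void push(T value)
    {
        {
            std::lock_guard<std::mutex> guard(mutex);
            items.push(std::move(value));
        }
        not_empty.notify_one();
    }
    void pop(T & value)
    {
        std::unique_lock<std::mutex> lock(mutex);
        not_empty.wait(lock, [this] { return !items.empty(); });
        value = std::move(items.front());
        items.pop();
    }
};

int main()
{
    std::atomic<bool> finished{false};
    ToyBlockingQueue<int> queue;        /// 0 plays the role of the empty finishing unit
    std::vector<std::thread> workers;
    for (int i = 0; i < 4; ++i)
        workers.emplace_back([&]
        {
            while (!finished)
            {
                int unit;
                queue.pop(unit);        /// blocks, like update_queue.pop(first_popped)
                if (finished)
                    break;
                /// ... perform the update here ...
            }
        });

    finished = true;                    /// shutdown: wake every blocked worker exactly once
    for (size_t i = 0; i < workers.size(); ++i)
        queue.push(0);
    for (auto & worker : workers)
        worker.join();
}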
@ -275,10 +300,16 @@ CacheDictionary::FindResult CacheDictionary::findCellIdx(const Key & id, const C
void CacheDictionary::has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8> & out) const
{
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
std::unordered_map<Key, std::vector<size_t>> outdated_ids;
/// There are three types of ids.
/// - Valid ids. These ids are present in the local cache and their lifetime has not expired.
/// - CacheExpired ids. Ids that are in the local cache, but their values are stale (their lifetime has expired).
/// - CacheNotFound ids. We have to go to the external storage to get their values.
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
std::unordered_map<Key, std::vector<size_t>> cache_expired_ids;
std::unordered_map<Key, std::vector<size_t>> cache_not_found_ids;
size_t cache_hit = 0;
const auto rows = ext::size(ids);
{
@ -291,49 +322,97 @@ void CacheDictionary::has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8>
const auto id = ids[row];
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto insert_to_answer_routine = [&] ()
{
out[row] = !cells[cell_idx].isDefault();
};
if (!find_result.valid)
{
outdated_ids[id].push_back(row);
if (find_result.outdated)
++cache_expired;
{
cache_expired_ids[id].push_back(row);
if (allow_read_expired_keys)
insert_to_answer_routine();
}
else
++cache_not_found;
{
cache_not_found_ids[id].push_back(row);
}
}
else
{
++cache_hit;
const auto & cell = cells[cell_idx];
out[row] = !cell.isDefault();
insert_to_answer_routine();
}
}
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
hit_count.fetch_add(rows - cache_expired_ids.size() - cache_not_found_ids.size(), std::memory_order_release);
if (outdated_ids.empty())
return;
if (cache_not_found_ids.empty())
{
/// Nothing to update - return.
if (cache_expired_ids.empty())
return;
std::vector<Key> required_ids(outdated_ids.size());
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), [](auto & pair) { return pair.first; });
/// request new values
update(
required_ids,
[&](const auto id, const auto)
if (allow_read_expired_keys)
{
for (const auto row : outdated_ids[id])
out[row] = true;
},
[&](const auto id, const auto)
{
for (const auto row : outdated_ids[id])
out[row] = false;
});
std::vector<Key> required_expired_ids;
required_expired_ids.reserve(cache_expired_ids.size());
std::transform(
std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_expired_ids), [](auto & pair) { return pair.first; });
/// Callbacks are empty because we don't want them to fire after an unknown period of time.
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_expired_ids);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
/// Update is async - no need to wait.
return;
}
}
/// At this point both kinds of keys may be present: cache_expired_ids and cache_not_found_ids.
/// We will update them all synchronously.
std::vector<Key> required_ids;
required_ids.reserve(cache_not_found_ids.size() + cache_expired_ids.size());
std::transform(
std::begin(cache_not_found_ids), std::end(cache_not_found_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
std::transform(
std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
auto on_cell_updated = [&] (const Key id, const size_t)
{
for (const auto row : cache_not_found_ids[id])
out[row] = true;
for (const auto row : cache_expired_ids[id])
out[row] = true;
};
auto on_id_not_found = [&] (const Key id, const size_t)
{
for (const auto row : cache_not_found_ids[id])
out[row] = false;
for (const auto row : cache_expired_ids[id])
out[row] = true;
};
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_ids, on_cell_updated, on_id_not_found);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
waitForCurrentUpdateFinish(update_unit_ptr);
}
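To restate the branching in has() above: an asynchronous (fire-and-forget) update is legal only when nothing is missing outright. A hedged helper capturing that condition (the function name is illustrative, not part of the class):

#include <cstddef>

/// True when every requested key is at least present in the cache (possibly expired)
/// and the dictionary is allowed to serve expired values while refreshing them.
bool canUpdateAsynchronously(size_t cache_not_found_count, size_t cache_expired_count, bool allow_read_expired_keys)
{
    return cache_not_found_count == 0 && cache_expired_count > 0 && allow_read_expired_keys;
}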
@ -590,7 +669,8 @@ void registerDictionaryCache(DictionaryFactory & factory)
DictionarySourcePtr source_ptr) -> DictionaryPtr
{
if (dict_struct.key)
throw Exception{"'key' is not supported for dictionary of layout 'cache'", ErrorCodes::UNSUPPORTED_METHOD};
throw Exception{"'key' is not supported for dictionary of layout 'cache'",
ErrorCodes::UNSUPPORTED_METHOD};
if (dict_struct.range_min || dict_struct.range_max)
throw Exception{full_name
@ -598,9 +678,11 @@ void registerDictionaryCache(DictionaryFactory & factory)
"for a dictionary of layout 'range_hashed'",
ErrorCodes::BAD_ARGUMENTS};
const auto & layout_prefix = config_prefix + ".layout";
const auto size = config.getInt(layout_prefix + ".cache.size_in_cells");
const size_t size = config.getUInt64(layout_prefix + ".cache.size_in_cells");
if (size == 0)
throw Exception{full_name + ": dictionary of layout 'cache' cannot have 0 cells", ErrorCodes::TOO_SMALL_BUFFER_SIZE};
throw Exception{full_name + ": dictionary of layout 'cache' cannot have 0 cells",
ErrorCodes::TOO_SMALL_BUFFER_SIZE};
const bool require_nonempty = config.getBool(config_prefix + ".require_nonempty", false);
if (require_nonempty)
@ -610,10 +692,284 @@ void registerDictionaryCache(DictionaryFactory & factory)
const String database = config.getString(config_prefix + ".database", "");
const String name = config.getString(config_prefix + ".name");
const DictionaryLifetime dict_lifetime{config, config_prefix + ".lifetime"};
return std::make_unique<CacheDictionary>(database, name, dict_struct, std::move(source_ptr), dict_lifetime, size);
const size_t max_update_queue_size =
config.getUInt64(layout_prefix + ".cache.max_update_queue_size", 100000);
if (max_update_queue_size == 0)
throw Exception{name + ": dictionary of layout 'cache' cannot have empty update queue of size 0",
ErrorCodes::TOO_SMALL_BUFFER_SIZE};
const bool allow_read_expired_keys =
config.getBool(layout_prefix + ".cache.allow_read_expired_keys", false);
const size_t update_queue_push_timeout_milliseconds =
config.getUInt64(layout_prefix + ".cache.update_queue_push_timeout_milliseconds", 10);
if (update_queue_push_timeout_milliseconds < 10)
throw Exception{name + ": dictionary of layout 'cache' have too little update_queue_push_timeout",
ErrorCodes::BAD_ARGUMENTS};
const size_t max_threads_for_updates =
config.getUInt64(layout_prefix + ".max_threads_for_updates", 4);
if (max_threads_for_updates == 0)
throw Exception{name + ": dictionary of layout 'cache' cannot have zero threads for updates.",
ErrorCodes::BAD_ARGUMENTS};
return std::make_unique<CacheDictionary>(
database, name, dict_struct, std::move(source_ptr), dict_lifetime, size,
allow_read_expired_keys, max_update_queue_size, update_queue_push_timeout_milliseconds,
max_threads_for_updates);
};
factory.registerLayout("cache", create_layout, false);
}
void CacheDictionary::updateThreadFunction()
{
setThreadName("AsyncUpdater");
while (!finished)
{
UpdateUnitPtr first_popped;
update_queue.pop(first_popped);
if (finished)
break;
/// Here we pop as many unit pointers from update queue as we can.
/// We fix current size to avoid livelock (or too long waiting),
/// when this thread pops from the queue and other threads push to the queue.
const size_t current_queue_size = update_queue.size();
if (current_queue_size > 0)
LOG_TRACE(log, "Performing bunch of keys update in cache dictionary with "
<< current_queue_size + 1 << " keys");
std::vector<UpdateUnitPtr> update_request;
update_request.reserve(current_queue_size + 1);
update_request.emplace_back(first_popped);
UpdateUnitPtr current_unit_ptr;
while (update_request.size() < current_queue_size + 1 && update_queue.tryPop(current_unit_ptr))
update_request.emplace_back(std::move(current_unit_ptr));
BunchUpdateUnit bunch_update_unit(update_request);
try
{
/// Update a bunch of ids.
update(bunch_update_unit);
/// Notify all threads that the update of the bunch of ids,
/// which included their own ids, has finished.
std::unique_lock<std::mutex> lock(update_mutex);
for (auto & unit_ptr: update_request)
unit_ptr->is_done = true;
is_update_finished.notify_all();
}
catch (...)
{
std::unique_lock<std::mutex> lock(update_mutex);
/// This is a big problem, because one bad query can make other threads fail with an unrelated exception.
/// So at this point all threads (and queries) will receive the same exception.
for (auto & unit_ptr: update_request)
unit_ptr->current_exception = std::current_exception();
is_update_finished.notify_all();
}
}
}
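The drain loop above snapshots the queue size before popping, so producers that keep pushing cannot pin the updater inside this loop forever. The same bounded-batch idiom in isolation (std::deque stands in for the concurrent queue):

#include <cstddef>
#include <deque>
#include <vector>

std::vector<int> drainBatch(std::deque<int> & queue, int first_popped)
{
    const size_t snapshot = queue.size();   /// fixed before draining, as in the code above
    std::vector<int> batch;
    batch.reserve(snapshot + 1);
    batch.push_back(first_popped);
    while (batch.size() < snapshot + 1 && !queue.empty())
    {
        batch.push_back(queue.front());
        queue.pop_front();
    }
    return batch;
}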
void CacheDictionary::waitForCurrentUpdateFinish(UpdateUnitPtr & update_unit_ptr) const
{
std::unique_lock<std::mutex> lock(update_mutex);
/*
* We wait here without any timeout to avoid SEGFAULTs.
* Consider the case when the timeout for the wait has expired, and the main query's thread ended
* with an exception or some other error, but the UpdateUnit with its callbacks is left in the queue.
* Those callbacks capture variables from the current thread (most of them lie on the stack of the
* finished thread) that were intended for a synchronous update. The AsyncUpdater thread can then
* touch deallocated memory and explode.
* */
is_update_finished.wait(
lock,
[&] {return update_unit_ptr->is_done || update_unit_ptr->current_exception; });
if (update_unit_ptr->current_exception)
std::rethrow_exception(update_unit_ptr->current_exception);
}
void CacheDictionary::tryPushToUpdateQueueOrThrow(UpdateUnitPtr & update_unit_ptr) const
{
if (!update_queue.tryPush(update_unit_ptr, update_queue_push_timeout_milliseconds))
throw DB::Exception(
"Cannot push to internal update queue in dictionary " + getFullName() + ". Timelimit of " +
std::to_string(update_queue_push_timeout_milliseconds) + " ms. exceeded. Current queue size is " +
std::to_string(update_queue.size()), ErrorCodes::CACHE_DICTIONARY_UPDATE_FAIL);
}
void CacheDictionary::update(BunchUpdateUnit & bunch_update_unit) const
{
CurrentMetrics::Increment metric_increment{CurrentMetrics::DictCacheRequests};
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequested, bunch_update_unit.getRequestedIds().size());
std::unordered_map<Key, UInt8> remaining_ids{bunch_update_unit.getRequestedIds().size()};
for (const auto id : bunch_update_unit.getRequestedIds())
remaining_ids.insert({id, 0});
const auto now = std::chrono::system_clock::now();
if (now > backoff_end_time.load())
{
try
{
if (error_count)
{
/// Recover after error: we have to clone the source here because
/// it could keep connections which should be reset after error.
source_ptr = source_ptr->clone();
}
Stopwatch watch;
auto stream = source_ptr->loadIds(bunch_update_unit.getRequestedIds());
const ProfilingScopedWriteRWLock write_lock{rw_lock, ProfileEvents::DictCacheLockWriteNs};
stream->readPrefix();
while (const auto block = stream->read())
{
const auto id_column = typeid_cast<const ColumnUInt64 *>(block.safeGetByPosition(0).column.get());
if (!id_column)
throw Exception{name + ": id column has type different from UInt64.", ErrorCodes::TYPE_MISMATCH};
const auto & ids = id_column->getData();
/// cache column pointers
const auto column_ptrs = ext::map<std::vector>(
ext::range(0, attributes.size()), [&block](size_t i) { return block.safeGetByPosition(i + 1).column.get(); });
for (const auto i : ext::range(0, ids.size()))
{
const auto id = ids[i];
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx];
for (const auto attribute_idx : ext::range(0, attributes.size()))
{
const auto & attribute_column = *column_ptrs[attribute_idx];
auto & attribute = attributes[attribute_idx];
setAttributeValue(attribute, cell_idx, attribute_column[i]);
}
/// if cell id is zero and zero does not map to this cell, then the cell is unused
if (cell.id == 0 && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
}
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
bunch_update_unit.informCallersAboutPresentId(id, cell_idx);
/// mark corresponding id as found
remaining_ids[id] = 1;
}
}
stream->readSuffix();
error_count = 0;
last_exception = std::exception_ptr{};
backoff_end_time = std::chrono::system_clock::time_point{};
ProfileEvents::increment(ProfileEvents::DictCacheRequestTimeNs, watch.elapsed());
}
catch (...)
{
++error_count;
last_exception = std::current_exception();
backoff_end_time = now + std::chrono::seconds(calculateDurationWithBackoff(rnd_engine, error_count));
tryLogException(last_exception, log, "Could not update cache dictionary '" + getFullName() +
"', next update is scheduled at " + ext::to_string(backoff_end_time.load()));
}
}
size_t not_found_num = 0, found_num = 0;
const ProfilingScopedWriteRWLock write_lock{rw_lock, ProfileEvents::DictCacheLockWriteNs};
/// Check which ids have not been found and require setting null_value
for (const auto & id_found_pair : remaining_ids)
{
if (id_found_pair.second)
{
++found_num;
continue;
}
++not_found_num;
const auto id = id_found_pair.first;
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx];
if (error_count)
{
if (find_result.outdated)
{
/// We have expired data for that `id` so we can continue using it.
bool was_default = cell.isDefault();
cell.setExpiresAt(backoff_end_time);
if (was_default)
cell.setDefault();
if (was_default)
bunch_update_unit.informCallersAboutAbsentId(id, cell_idx);
else
bunch_update_unit.informCallersAboutPresentId(id, cell_idx);
continue;
}
/// We don't have expired data for that `id` so all we can do is to rethrow `last_exception`.
std::rethrow_exception(last_exception);
}
/// Check if cell had not been occupied before and increment element counter if it hadn't
if (cell.id == 0 && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
}
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
/// Set null_value for each attribute
cell.setDefault();
for (auto & attribute : attributes)
setDefaultAttributeValue(attribute, cell_idx);
/// inform caller that the cell has not been found
bunch_update_unit.informCallersAboutAbsentId(id, cell_idx);
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedMiss, not_found_num);
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedFound, found_num);
ProfileEvents::increment(ProfileEvents::DictCacheRequests);
}
}
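On source failure, update() schedules backoff_end_time via calculateDurationWithBackoff. Its exact formula is not shown in this diff; the following is only a generic exponential-backoff-with-jitter sketch of the idea (the cap and growth rate are assumptions, not ClickHouse's actual values):

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <random>

std::chrono::seconds backoffDuration(std::mt19937 & rng, size_t error_count)
{
    const unsigned max_sec = 600;                                  /// assumed cap
    const unsigned base = 1u << std::min<size_t>(error_count, 9);  /// 2, 4, 8, ... capped
    std::uniform_int_distribution<unsigned> jitter(0, base);
    return std::chrono::seconds(std::min(max_sec, base + jitter(rng)));
}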

View File

@ -4,12 +4,16 @@
#include <chrono>
#include <cmath>
#include <map>
#include <mutex>
#include <shared_mutex>
#include <utility>
#include <variant>
#include <vector>
#include <common/logger_useful.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnString.h>
#include <Common/ThreadPool.h>
#include <Common/ConcurrentBoundedQueue.h>
#include <pcg_random.hpp>
#include <Common/ArenaWithFreeLists.h>
#include <Common/CurrentMetrics.h>
@ -21,6 +25,22 @@
namespace DB
{
namespace ErrorCodes
{
extern const int CACHE_DICTIONARY_UPDATE_FAIL;
}
/*
*
* This dictionary is stored in a cache that has a fixed number of cells.
* These cells contain frequently used elements.
* When searching the dictionary, the cache is searched first, and a special heuristic is used:
* while looking for the key, we examine only max_collision_length elements.
* So our cache is not perfect. It can make errors like "the key is in the cache, but the cache says it is not".
* In this case we simply ask the external source for the key (bounding the search this way is what keeps cache lookups fast).
* You have to keep this logic in mind.
* */
class CacheDictionary final : public IDictionary
{
public:
@ -29,8 +49,14 @@ public:
const std::string & name_,
const DictionaryStructure & dict_struct_,
DictionarySourcePtr source_ptr_,
const DictionaryLifetime dict_lifetime_,
const size_t size_);
DictionaryLifetime dict_lifetime_,
size_t size_,
bool allow_read_expired_keys_,
size_t max_update_queue_size_,
size_t update_queue_push_timeout_milliseconds_,
size_t max_threads_for_updates);
~CacheDictionary() override;
const std::string & getDatabase() const override { return database; }
const std::string & getName() const override { return name; }
@ -55,7 +81,10 @@ public:
std::shared_ptr<const IExternalLoadable> clone() const override
{
return std::make_shared<CacheDictionary>(database, name, dict_struct, source_ptr->clone(), dict_lifetime, size);
return std::make_shared<CacheDictionary>(
database, name, dict_struct, source_ptr->clone(), dict_lifetime, size,
allow_read_expired_keys, max_update_queue_size,
update_queue_push_timeout_milliseconds, max_threads_for_updates);
}
const IDictionarySource * getSource() const override { return source_ptr.get(); }
@ -230,9 +259,6 @@ private:
template <typename DefaultGetter>
void getItemsString(Attribute & attribute, const PaddedPODArray<Key> & ids, ColumnString * out, DefaultGetter && get_default) const;
template <typename PresentIdHandler, typename AbsentIdHandler>
void update(const std::vector<Key> & requested_ids, PresentIdHandler && on_cell_updated, AbsentIdHandler && on_id_not_found) const;
PaddedPODArray<Key> getCachedIds() const;
bool isEmptyCell(const UInt64 idx) const;
@ -263,6 +289,11 @@ private:
const DictionaryStructure dict_struct;
mutable DictionarySourcePtr source_ptr;
const DictionaryLifetime dict_lifetime;
const bool allow_read_expired_keys;
const size_t max_update_queue_size;
const size_t update_queue_push_timeout_milliseconds;
const size_t max_threads_for_updates;
Logger * const log;
mutable std::shared_mutex rw_lock;
@ -284,8 +315,8 @@ private:
std::unique_ptr<ArenaWithFreeLists> string_arena;
mutable std::exception_ptr last_exception;
mutable size_t error_count = 0;
mutable std::chrono::system_clock::time_point backoff_end_time;
mutable std::atomic<size_t> error_count = 0;
mutable std::atomic<std::chrono::system_clock::time_point> backoff_end_time{std::chrono::system_clock::time_point{}};
mutable pcg64 rnd_engine;
@ -293,6 +324,166 @@ private:
mutable std::atomic<size_t> element_count{0};
mutable std::atomic<size_t> hit_count{0};
mutable std::atomic<size_t> query_count{0};
};
/// Fields and methods related to updating expired and not-found keys.
using PresentIdHandler = std::function<void(Key, size_t)>;
using AbsentIdHandler = std::function<void(Key, size_t)>;
/*
* Disclaimer: this comment is written not for fun.
*
* How the update goes: we basically have a method like get(keys)->values. Values are cached, so sometimes we
* can return them from the cache. For values not in cache, we query them from the dictionary, and add to the
* cache. The cache is lossy, so we can't expect it to store all the keys, and we store them separately. Normally,
* they would be passed as a return value of get(), but for Unknown Reasons the dictionaries use a baroque
* interface where get() accepts two callbacks: one that it calls for found values, and one for not-found values.
*
* Now we make it even uglier by doing this from multiple threads. The missing values are retrieved from the
* dictionary in a background thread, and this thread calls the provided callback. So if you provide the callbacks,
* you MUST wait until the background update finishes, or god knows what happens. Unfortunately, we have no
* way to check that you did this right, so good luck.
*/
struct UpdateUnit
{
UpdateUnit(std::vector<Key> requested_ids_,
PresentIdHandler present_id_handler_,
AbsentIdHandler absent_id_handler_) :
requested_ids(std::move(requested_ids_)),
present_id_handler(present_id_handler_),
absent_id_handler(absent_id_handler_) {}
explicit UpdateUnit(std::vector<Key> requested_ids_) :
requested_ids(std::move(requested_ids_)),
present_id_handler([](Key, size_t){}),
absent_id_handler([](Key, size_t){}) {}
std::vector<Key> requested_ids;
PresentIdHandler present_id_handler;
AbsentIdHandler absent_id_handler;
std::atomic<bool> is_done{false};
std::exception_ptr current_exception{nullptr};
};
using UpdateUnitPtr = std::shared_ptr<UpdateUnit>;
using UpdateQueue = ConcurrentBoundedQueue<UpdateUnitPtr>;
/*
* This class is used to concatenate requested_keys.
*
* Imagine that we have several UpdateUnits with different vectors of keys and callbacks for those keys.
* We concatenate them into one long vector of keys that looks like:
*
* a1...ak_1 b1...bk_2 c1...ck_3,
*
* where a1...ak_1 are the requested_ids from the first UpdateUnit.
* In addition, we have the same number (three in this case) of callbacks.
* This class helps us find the callback (or several callbacks) for a particular key.
* */
class BunchUpdateUnit
{
public:
explicit BunchUpdateUnit(std::vector<UpdateUnitPtr> & update_request)
{
/// Compute the total count of all requested ids up front,
/// so that we do not make useless allocations later.
size_t total_requested_keys_count = 0;
for (auto & unit_ptr: update_request)
{
total_requested_keys_count += unit_ptr->requested_ids.size();
if (helper.empty())
helper.push_back(unit_ptr->requested_ids.size());
else
helper.push_back(unit_ptr->requested_ids.size() + helper.back());
present_id_handlers.emplace_back(unit_ptr->present_id_handler);
absent_id_handlers.emplace_back(unit_ptr->absent_id_handler);
}
concatenated_requested_ids.reserve(total_requested_keys_count);
for (auto & unit_ptr: update_request)
std::for_each(std::begin(unit_ptr->requested_ids), std::end(unit_ptr->requested_ids),
[&] (const Key & key) {concatenated_requested_ids.push_back(key);});
}
const std::vector<Key> & getRequestedIds()
{
return concatenated_requested_ids;
}
void informCallersAboutPresentId(Key id, size_t cell_idx)
{
for (size_t i = 0; i < concatenated_requested_ids.size(); ++i)
{
auto & curr = concatenated_requested_ids[i];
if (curr == id)
getPresentIdHandlerForPosition(i)(id, cell_idx);
}
}
void informCallersAboutAbsentId(Key id, size_t cell_idx)
{
for (size_t i = 0; i < concatenated_requested_ids.size(); ++i)
if (concatenated_requested_ids[i] == id)
getAbsentIdHandlerForPosition(i)(id, cell_idx);
}
private:
PresentIdHandler & getPresentIdHandlerForPosition(size_t position)
{
return present_id_handlers[getUpdateUnitNumberForRequestedIdPosition(position)];
}
AbsentIdHandler & getAbsentIdHandlerForPosition(size_t position)
{
return absent_id_handlers[getUpdateUnitNumberForRequestedIdPosition(position)];
}
size_t getUpdateUnitNumberForRequestedIdPosition(size_t position)
{
    /// `helper` holds running prefix sums of the unit sizes, so the owning unit is the
    /// first one whose prefix sum is strictly greater than the position.
    return std::upper_bound(helper.begin(), helper.end(), position) - helper.begin();
}
std::vector<Key> concatenated_requested_ids;
std::vector<PresentIdHandler> present_id_handlers;
std::vector<AbsentIdHandler> absent_id_handlers;
std::vector<size_t> helper;
};
mutable UpdateQueue update_queue;
ThreadPool update_pool;
/*
* Actually, we can divide all requested keys into two 'buckets'. There are only four possible states and they
* are described in the table.
*
* cache_not_found_ids |0|0|1|1|
* cache_expired_ids |0|1|0|1|
*
* 0 - if set is empty, 1 - otherwise
*
* Only if there are no cache_not_found_ids but there are some cache_expired_ids
* (with the allow_read_expired_keys setting enabled) can we perform an async update.
* Otherwise we have to concatenate the ids and update them synchronously.
*
*/
void updateThreadFunction();
void update(BunchUpdateUnit & bunch_update_unit) const;
void tryPushToUpdateQueueOrThrow(UpdateUnitPtr & update_unit_ptr) const;
void waitForCurrentUpdateFinish(UpdateUnitPtr & update_unit_ptr) const;
mutable std::mutex update_mutex;
mutable std::condition_variable is_update_finished;
std::atomic<bool> finished{false};
};
}
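The `helper` vector in BunchUpdateUnit above is a running prefix sum over the unit sizes; a standalone illustration of the position-to-unit mapping (the sizes are made up):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const std::vector<size_t> sizes = {3, 2, 4};   /// requested_ids sizes of three UpdateUnits
    std::vector<size_t> helper;                    /// running prefix sums: {3, 5, 9}
    for (size_t size : sizes)
        helper.push_back(helper.empty() ? size : helper.back() + size);

    auto unit_for_position = [&](size_t position)
    {
        /// The first prefix sum strictly greater than the position owns it.
        return std::upper_bound(helper.begin(), helper.end(), position) - helper.begin();
    };

    assert(unit_for_position(0) == 0);   /// positions 0..2 -> unit 0
    assert(unit_for_position(3) == 1);   /// positions 3..4 -> unit 1
    assert(unit_for_position(8) == 2);   /// positions 5..8 -> unit 2
}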

View File

@ -40,11 +40,13 @@ void CacheDictionary::getItemsNumberImpl(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const
{
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
std::unordered_map<Key, std::vector<size_t>> outdated_ids;
std::unordered_map<Key, std::vector<size_t>> cache_expired_ids;
std::unordered_map<Key, std::vector<size_t>> cache_not_found_ids;
auto & attribute_array = std::get<ContainerPtrType<AttributeType>>(attribute.arrays);
const auto rows = ext::size(ids);
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
size_t cache_hit = 0;
{
const ProfilingScopedReadRWLock read_lock{rw_lock, ProfileEvents::DictCacheLockReadNs};
@ -61,52 +63,105 @@ void CacheDictionary::getItemsNumberImpl(
* 3. explicit defaults were specified and cell was set default. */
const auto find_result = findCellIdx(id, now);
auto update_routine = [&]()
{
const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
out[row] = cell.isDefault() ? get_default(row) : static_cast<OutputType>(attribute_array[cell_idx]);
};
if (!find_result.valid)
{
outdated_ids[id].push_back(row);
if (find_result.outdated)
++cache_expired;
{
cache_expired_ids[id].push_back(row);
if (allow_read_expired_keys)
update_routine();
}
else
++cache_not_found;
{
cache_not_found_ids[id].push_back(row);
}
}
else
{
++cache_hit;
const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
out[row] = cell.isDefault() ? get_default(row) : static_cast<OutputType>(attribute_array[cell_idx]);
update_routine();
}
}
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
hit_count.fetch_add(rows - cache_expired_ids.size() - cache_not_found_ids.size(), std::memory_order_release);
if (outdated_ids.empty())
return;
if (cache_not_found_ids.empty())
{
/// Nothing to update - return
if (cache_expired_ids.empty())
return;
std::vector<Key> required_ids(outdated_ids.size());
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), [](auto & pair) { return pair.first; });
/// request new values
update(
required_ids,
[&](const auto id, const auto cell_idx)
/// Update asynchronously only if allow_read_expired_keys is enabled.
if (allow_read_expired_keys)
{
const auto attribute_value = attribute_array[cell_idx];
std::vector<Key> required_expired_ids;
required_expired_ids.reserve(cache_expired_ids.size());
std::transform(std::begin(cache_expired_ids), std::end(cache_expired_ids), std::back_inserter(required_expired_ids),
[](auto & pair) { return pair.first; });
for (const size_t row : outdated_ids[id])
out[row] = static_cast<OutputType>(attribute_value);
},
[&](const auto id, const auto)
{
for (const size_t row : outdated_ids[id])
out[row] = get_default(row);
});
/// request new values
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_expired_ids);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
/// Nothing to do - return
return;
}
}
/// From this point on we have to update all keys synchronously.
/// Maybe allow_read_expired_keys is disabled, and there are
/// no cache_not_found_ids but there are some cache_expired_ids.
std::vector<Key> required_ids;
required_ids.reserve(cache_not_found_ids.size() + cache_expired_ids.size());
std::transform(
std::begin(cache_not_found_ids), std::end(cache_not_found_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
std::transform(
std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
auto on_cell_updated = [&] (const auto id, const auto cell_idx)
{
const auto attribute_value = attribute_array[cell_idx];
for (const size_t row : cache_not_found_ids[id])
out[row] = static_cast<OutputType>(attribute_value);
for (const size_t row : cache_expired_ids[id])
out[row] = static_cast<OutputType>(attribute_value);
};
auto on_id_not_found = [&] (const auto id, const auto)
{
for (const size_t row : cache_not_found_ids[id])
out[row] = get_default(row);
for (const size_t row : cache_expired_ids[id])
out[row] = get_default(row);
};
/// Request new values
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_ids, on_cell_updated, on_id_not_found);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
waitForCurrentUpdateFinish(update_unit_ptr);
}
template <typename DefaultGetter>
@ -161,12 +216,13 @@ void CacheDictionary::getItemsString(
out->getOffsets().resize_assume_reserved(0);
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
std::unordered_map<Key, std::vector<size_t>> outdated_ids;
std::unordered_map<Key, std::vector<size_t>> cache_expired_ids;
std::unordered_map<Key, std::vector<size_t>> cache_not_found_ids;
/// we are going to store every string separately
std::unordered_map<Key, String> map;
size_t total_length = 0;
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
size_t cache_hit = 0;
{
const ProfilingScopedReadRWLock read_lock{rw_lock, ProfileEvents::DictCacheLockReadNs};
@ -176,17 +232,10 @@ void CacheDictionary::getItemsString(
const auto id = ids[row];
const auto find_result = findCellIdx(id, now);
if (!find_result.valid)
auto insert_value_routine = [&]()
{
outdated_ids[id].push_back(row);
if (find_result.outdated)
++cache_expired;
else
++cache_not_found;
}
else
{
++cache_hit;
const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
@ -195,37 +244,82 @@ void CacheDictionary::getItemsString(
map[id] = String{string_ref};
total_length += string_ref.size + 1;
};
if (!find_result.valid)
{
if (find_result.outdated)
{
cache_expired_ids[id].push_back(row);
if (allow_read_expired_keys)
insert_value_routine();
}
else
    cache_not_found_ids[id].push_back(row);
}
else
{
++cache_hit;
insert_value_routine();
}
}
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
hit_count.fetch_add(rows - cache_expired_ids.size() - cache_not_found_ids.size(), std::memory_order_release);
/// request new values
if (!outdated_ids.empty())
/// Async update of expired keys.
if (cache_not_found_ids.empty())
{
std::vector<Key> required_ids(outdated_ids.size());
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), [](auto & pair) { return pair.first; });
if (allow_read_expired_keys && !cache_expired_ids.empty())
{
std::vector<Key> required_expired_ids;
required_expired_ids.reserve(cache_expired_ids.size());
std::transform(std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_expired_ids), [](auto & pair) { return pair.first; });
update(
required_ids,
[&](const auto id, const auto cell_idx)
{
const auto attribute_value = attribute_array[cell_idx];
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_expired_ids);
map[id] = String{attribute_value};
total_length += (attribute_value.size + 1) * outdated_ids[id].size();
},
[&](const auto id, const auto)
{
for (const auto row : outdated_ids[id])
total_length += get_default(row).size + 1;
});
tryPushToUpdateQueueOrThrow(update_unit_ptr);
/// Do not return at this point, because there is some extra work to do at the end of this method.
}
}
/// Request new values synchronously.
/// We have to request both cache_not_found_ids and cache_expired_ids.
if (!cache_not_found_ids.empty())
{
std::vector<Key> required_ids;
required_ids.reserve(cache_not_found_ids.size() + cache_expired_ids.size());
std::transform(
std::begin(cache_not_found_ids), std::end(cache_not_found_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
std::transform(
std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
auto on_cell_updated = [&] (const auto id, const auto cell_idx)
{
const auto attribute_value = attribute_array[cell_idx];
map[id] = String{attribute_value};
total_length += (attribute_value.size + 1) * cache_not_found_ids[id].size();
};
auto on_id_not_found = [&] (const auto id, const auto)
{
for (const auto row : cache_not_found_ids[id])
total_length += get_default(row).size + 1;
};
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_ids, on_cell_updated, on_id_not_found);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
waitForCurrentUpdateFinish(update_unit_ptr);
}
out->getChars().reserve(total_length);
@ -240,167 +334,4 @@ void CacheDictionary::getItemsString(
}
}
template <typename PresentIdHandler, typename AbsentIdHandler>
void CacheDictionary::update(
const std::vector<Key> & requested_ids, PresentIdHandler && on_cell_updated, AbsentIdHandler && on_id_not_found) const
{
CurrentMetrics::Increment metric_increment{CurrentMetrics::DictCacheRequests};
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequested, requested_ids.size());
std::unordered_map<Key, UInt8> remaining_ids{requested_ids.size()};
for (const auto id : requested_ids)
remaining_ids.insert({id, 0});
const auto now = std::chrono::system_clock::now();
const ProfilingScopedWriteRWLock write_lock{rw_lock, ProfileEvents::DictCacheLockWriteNs};
if (now > backoff_end_time)
{
try
{
if (error_count)
{
/// Recover after error: we have to clone the source here because
/// it could keep connections which should be reset after error.
source_ptr = source_ptr->clone();
}
Stopwatch watch;
auto stream = source_ptr->loadIds(requested_ids);
stream->readPrefix();
while (const auto block = stream->read())
{
const auto id_column = typeid_cast<const ColumnUInt64 *>(block.safeGetByPosition(0).column.get());
if (!id_column)
throw Exception{name + ": id column has type different from UInt64.", ErrorCodes::TYPE_MISMATCH};
const auto & ids = id_column->getData();
/// cache column pointers
const auto column_ptrs = ext::map<std::vector>(
ext::range(0, attributes.size()), [&block](size_t i) { return block.safeGetByPosition(i + 1).column.get(); });
for (const auto i : ext::range(0, ids.size()))
{
const auto id = ids[i];
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx];
for (const auto attribute_idx : ext::range(0, attributes.size()))
{
const auto & attribute_column = *column_ptrs[attribute_idx];
auto & attribute = attributes[attribute_idx];
setAttributeValue(attribute, cell_idx, attribute_column[i]);
}
/// if cell id is zero and zero does not map to this cell, then the cell is unused
if (cell.id == 0 && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
}
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
/// inform caller
on_cell_updated(id, cell_idx);
/// mark corresponding id as found
remaining_ids[id] = 1;
}
}
stream->readSuffix();
error_count = 0;
last_exception = std::exception_ptr{};
backoff_end_time = std::chrono::system_clock::time_point{};
ProfileEvents::increment(ProfileEvents::DictCacheRequestTimeNs, watch.elapsed());
}
catch (...)
{
++error_count;
last_exception = std::current_exception();
backoff_end_time = now + std::chrono::seconds(calculateDurationWithBackoff(rnd_engine, error_count));
tryLogException(last_exception, log, "Could not update cache dictionary '" + getFullName() +
"', next update is scheduled at " + ext::to_string(backoff_end_time));
}
}
size_t not_found_num = 0, found_num = 0;
/// Check which ids have not been found and require setting null_value
for (const auto & id_found_pair : remaining_ids)
{
if (id_found_pair.second)
{
++found_num;
continue;
}
++not_found_num;
const auto id = id_found_pair.first;
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx];
if (error_count)
{
if (find_result.outdated)
{
/// We have expired data for that `id` so we can continue using it.
bool was_default = cell.isDefault();
cell.setExpiresAt(backoff_end_time);
if (was_default)
cell.setDefault();
if (was_default)
on_id_not_found(id, cell_idx);
else
on_cell_updated(id, cell_idx);
continue;
}
/// We don't have expired data for that `id` so all we can do is to rethrow `last_exception`.
std::rethrow_exception(last_exception);
}
/// Check if cell had not been occupied before and increment element counter if it hadn't
if (cell.id == 0 && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
}
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
/// Set null_value for each attribute
cell.setDefault();
for (auto & attribute : attributes)
setDefaultAttributeValue(attribute, cell_idx);
/// inform caller that the cell has not been found
on_id_not_found(id, cell_idx);
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedMiss, not_found_num);
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedFound, found_num);
ProfileEvents::increment(ProfileEvents::DictCacheRequests);
}
}

View File

@ -97,6 +97,7 @@ static FormatSettings getOutputFormatSetting(const Settings & settings, const Co
format_settings.template_settings.resultset_format = settings.format_template_resultset;
format_settings.template_settings.row_format = settings.format_template_row;
format_settings.template_settings.row_between_delimiter = settings.format_template_rows_between_delimiter;
format_settings.tsv.crlf_end_of_line = settings.output_format_tsv_crlf_end_of_line;
format_settings.write_statistics = settings.output_format_write_statistics;
format_settings.parquet.row_group_size = settings.output_format_parquet_row_group_size;
format_settings.schema.format_schema = settings.format_schema;
@ -144,9 +145,19 @@ BlockInputStreamPtr FormatFactory::getInput(
// Doesn't make sense to use parallel parsing with less than four threads
// (segmentator + two parsers + reader).
if (settings.input_format_parallel_parsing
&& file_segmentation_engine
&& settings.max_threads >= 4)
bool parallel_parsing = settings.input_format_parallel_parsing && file_segmentation_engine && settings.max_threads >= 4;
if (parallel_parsing && name == "JSONEachRow")
{
/// FIXME ParallelParsingBlockInputStream doesn't support formats with non-trivial readPrefix() and readSuffix()
/// For JSONEachRow we can safely skip whitespace characters
skipWhitespaceIfAny(buf);
if (buf.eof() || *buf.position() == '[')
parallel_parsing = false; /// Disable it for JSONEachRow if data is in square brackets (see JSONEachRowRowInputFormat)
}
if (parallel_parsing)
{
const auto & input_getter = getCreators(name).input_processor_creator;
if (!input_getter)

View File

@ -1,6 +1,7 @@
#pragma once
#include <Core/Types.h>
#include <Columns/IColumn.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <IO/BufferWithOwnMemory.h>
@ -9,7 +10,6 @@
#include <unordered_map>
#include <boost/noncopyable.hpp>
namespace DB
{
@ -53,7 +53,9 @@ public:
/// This callback allows to perform some additional actions after writing a single row.
/// It's initial purpose was to flush Kafka message for each row.
using WriteCallback = std::function<void()>;
using WriteCallback = std::function<void(
const Columns & columns,
size_t row)>;
private:
using InputCreator = std::function<BlockInputStreamPtr(

View File

@ -64,6 +64,7 @@ struct FormatSettings
struct TSV
{
bool empty_as_default = false;
bool crlf_end_of_line = false;
};
TSV tsv;

View File

@ -45,7 +45,7 @@ try
BlockInputStreamPtr block_input = std::make_shared<InputStreamFromInputFormat>(std::move(input_format));
BlockOutputStreamPtr block_output = std::make_shared<OutputStreamToOutputFormat>(
std::make_shared<TabSeparatedRowOutputFormat>(out_buf, sample, false, false, [] {}, format_settings));
std::make_shared<TabSeparatedRowOutputFormat>(out_buf, sample, false, false, [](const Columns & /* columns */, size_t /* row */){}, format_settings));
copyData(*block_input, *block_output);
return 0;

View File

@ -35,7 +35,6 @@
#include <Columns/ColumnsCommon.h>
#include <Common/FieldVisitors.h>
#include <Common/assert_cast.h>
#include <Interpreters/ExpressionActions.h>
#include <Functions/IFunctionAdaptors.h>
#include <Functions/FunctionsMiscellaneous.h>
#include <Functions/FunctionHelpers.h>

View File

@ -0,0 +1,66 @@
#include <Functions/IFunctionImpl.h>
#include <Functions/FunctionFactory.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/NullWriteBuffer.h>
namespace DB
{
/// Returns the size the block would occupy on disk (without taking compression into account).
class FunctionBlockSerializedSize : public IFunction
{
public:
static constexpr auto name = "blockSerializedSize";
static FunctionPtr create(const Context &)
{
return std::make_shared<FunctionBlockSerializedSize>();
}
String getName() const override { return name; }
bool useDefaultImplementationForNulls() const override { return false; }
size_t getNumberOfArguments() const override { return 0; }
bool isVariadic() const override { return true; }
DataTypePtr getReturnTypeImpl(const DataTypes & /*arguments*/) const override
{
return std::make_shared<DataTypeUInt64>();
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{
UInt64 size = 0;
for (size_t i = 0; i < arguments.size(); ++i)
size += blockSerializedSizeOne(block.getByPosition(arguments[i]));
block.getByPosition(result).column = DataTypeUInt64().createColumnConst(
input_rows_count, size)->convertToFullColumnIfConst();
}
UInt64 blockSerializedSizeOne(const ColumnWithTypeAndName & elem) const
{
ColumnPtr full_column = elem.column->convertToFullColumnIfConst();
IDataType::SerializeBinaryBulkSettings settings;
NullWriteBuffer out;
settings.getter = [&out](IDataType::SubstreamPath) -> WriteBuffer * { return &out; };
IDataType::SerializeBinaryBulkStatePtr state;
elem.type->serializeBinaryBulkWithMultipleStreams(*full_column,
0 /** offset */, 0 /** limit */,
settings, state);
return out.count();
}
};
void registerFunctionBlockSerializedSize(FunctionFactory & factory)
{
factory.registerFunction<FunctionBlockSerializedSize>();
}
}

View File

@ -73,9 +73,15 @@ public:
for (size_t i = 0; i < input_rows_count; ++i)
{
StringRef source = column_concrete->getDataAt(i);
int status = 0;
std::string demangled = demangle(source.data, status);
result_column->insertDataWithTerminatingZero(demangled.data(), demangled.size() + 1);
auto demangled = tryDemangle(source.data);
if (demangled)
{
result_column->insertDataWithTerminatingZero(demangled.get(), strlen(demangled.get()));
}
else
{
result_column->insertDataWithTerminatingZero(source.data, source.size);
}
}
block.getByPosition(result).column = std::move(result_column);

View File

@ -14,6 +14,7 @@ void registerFunctionFQDN(FunctionFactory &);
void registerFunctionVisibleWidth(FunctionFactory &);
void registerFunctionToTypeName(FunctionFactory &);
void registerFunctionGetSizeOfEnumType(FunctionFactory &);
void registerFunctionBlockSerializedSize(FunctionFactory &);
void registerFunctionToColumnTypeName(FunctionFactory &);
void registerFunctionDumpColumnStructure(FunctionFactory &);
void registerFunctionDefaultValueOfArgumentType(FunctionFactory &);
@ -72,6 +73,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory)
registerFunctionVisibleWidth(factory);
registerFunctionToTypeName(factory);
registerFunctionGetSizeOfEnumType(factory);
registerFunctionBlockSerializedSize(factory);
registerFunctionToColumnTypeName(factory);
registerFunctionDumpColumnStructure(factory);
registerFunctionDefaultValueOfArgumentType(factory);

View File

@ -16,6 +16,7 @@ struct ConnectionTimeouts
Poco::Timespan receive_timeout;
Poco::Timespan tcp_keep_alive_timeout;
Poco::Timespan http_keep_alive_timeout;
Poco::Timespan secure_connection_timeout;
ConnectionTimeouts() = default;
@ -26,7 +27,8 @@ struct ConnectionTimeouts
send_timeout(send_timeout_),
receive_timeout(receive_timeout_),
tcp_keep_alive_timeout(0),
http_keep_alive_timeout(0)
http_keep_alive_timeout(0),
secure_connection_timeout(connection_timeout)
{
}
@ -38,7 +40,8 @@ struct ConnectionTimeouts
send_timeout(send_timeout_),
receive_timeout(receive_timeout_),
tcp_keep_alive_timeout(tcp_keep_alive_timeout_),
http_keep_alive_timeout(0)
http_keep_alive_timeout(0),
secure_connection_timeout(connection_timeout)
{
}
ConnectionTimeouts(const Poco::Timespan & connection_timeout_,
@ -50,10 +53,25 @@ struct ConnectionTimeouts
send_timeout(send_timeout_),
receive_timeout(receive_timeout_),
tcp_keep_alive_timeout(tcp_keep_alive_timeout_),
http_keep_alive_timeout(http_keep_alive_timeout_)
http_keep_alive_timeout(http_keep_alive_timeout_),
secure_connection_timeout(connection_timeout)
{
}
ConnectionTimeouts(const Poco::Timespan & connection_timeout_,
const Poco::Timespan & send_timeout_,
const Poco::Timespan & receive_timeout_,
const Poco::Timespan & tcp_keep_alive_timeout_,
const Poco::Timespan & http_keep_alive_timeout_,
const Poco::Timespan & secure_connection_timeout_)
: connection_timeout(connection_timeout_),
send_timeout(send_timeout_),
receive_timeout(receive_timeout_),
tcp_keep_alive_timeout(tcp_keep_alive_timeout_),
http_keep_alive_timeout(http_keep_alive_timeout_),
secure_connection_timeout(secure_connection_timeout_)
{
}
static Poco::Timespan saturate(const Poco::Timespan & timespan, const Poco::Timespan & limit)
{
@ -69,7 +87,8 @@ struct ConnectionTimeouts
saturate(send_timeout, limit),
saturate(receive_timeout, limit),
saturate(tcp_keep_alive_timeout, limit),
saturate(http_keep_alive_timeout, limit));
saturate(http_keep_alive_timeout, limit),
saturate(secure_connection_timeout, limit));
}
/// Timeouts for the case when we have just single attempt to connect.
@ -81,7 +100,7 @@ struct ConnectionTimeouts
/// Timeouts for the case when we will try many addresses in a loop.
static ConnectionTimeouts getTCPTimeoutsWithFailover(const Settings & settings)
{
return ConnectionTimeouts(settings.connect_timeout_with_failover_ms, settings.send_timeout, settings.receive_timeout, settings.tcp_keep_alive_timeout);
return ConnectionTimeouts(settings.connect_timeout_with_failover_ms, settings.send_timeout, settings.receive_timeout, settings.tcp_keep_alive_timeout, 0, settings.connect_timeout_with_failover_secure_ms);
}
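A hedged example of the new six-argument constructor (the numeric values are purely illustrative; Poco::Timespan takes seconds and microseconds):

ConnectionTimeouts makeFailoverTimeouts()
{
    return ConnectionTimeouts(
        Poco::Timespan(0, 50 * 1000),    /// connection timeout: 50 ms
        Poco::Timespan(300, 0),          /// send timeout
        Poco::Timespan(300, 0),          /// receive timeout
        Poco::Timespan(0, 0),            /// tcp keep-alive timeout
        Poco::Timespan(0, 0),            /// http keep-alive timeout
        Poco::Timespan(0, 100 * 1000));  /// secure connection timeout: 100 ms
}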
static ConnectionTimeouts getHTTPTimeouts(const Context & context)

View File

@ -0,0 +1,16 @@
#include <IO/NullWriteBuffer.h>
namespace DB
{
NullWriteBuffer::NullWriteBuffer(size_t buf_size, char * existing_memory, size_t alignment)
: BufferWithOwnMemory<WriteBuffer>(buf_size, existing_memory, alignment)
{
}
void NullWriteBuffer::nextImpl()
{
}
}

View File

@ -0,0 +1,18 @@
#pragma once
#include <IO/WriteBuffer.h>
#include <IO/BufferWithOwnMemory.h>
#include <boost/noncopyable.hpp>
namespace DB
{
/// Simply does nothing; can be used to measure the number of written bytes.
class NullWriteBuffer : public BufferWithOwnMemory<WriteBuffer>, boost::noncopyable
{
public:
NullWriteBuffer(size_t buf_size = 16 << 10, char * existing_memory = nullptr, size_t alignment = 0);
void nextImpl() override;
};
}
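A hedged usage sketch: because nextImpl() discards everything, the buffer's count() yields the number of bytes that were "written" (writeStringBinary here is the usual helper from IO/WriteHelpers.h):

#include <cstddef>
#include <string>
#include <IO/NullWriteBuffer.h>
#include <IO/WriteHelpers.h>

size_t serializedSizeOf(const std::string & value)
{
    DB::NullWriteBuffer out;
    DB::writeStringBinary(value, out);  /// varint length + bytes, all discarded
    out.next();                         /// flush the working buffer through nextImpl()
    return out.count();                 /// number of bytes passed through the buffer
}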

View File

@ -6,17 +6,28 @@
#include <port/unistd.h>
#include <IO/ReadBufferAIO.h>
#include <fstream>
#include <string>
namespace
{
std::string createTmpFileForEOFtest()
{
char pattern[] = "/tmp/fileXXXXXX";
char * dir = ::mkdtemp(pattern);
return std::string(dir) + "/foo";
if (char * dir = ::mkdtemp(pattern); dir)
{
return std::string(dir) + "/foo";
}
else
{
/// We have no /tmp in Docker,
/// so we have to use the root directory.
std::string almost_rand_dir = std::string{"/"} + std::to_string(rand()) + "foo";
return almost_rand_dir;
}
}
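As background for the fallback above, a self-contained sketch of the POSIX ::mkdtemp contract the test relies on: the template must end in "XXXXXX", it is modified in place, and nullptr signals failure:

    #include <cstdio>
    #include <cstdlib>

    int main()
    {
        char tmpl[] = "/tmp/fileXXXXXX";
        if (char * dir = ::mkdtemp(tmpl); dir != nullptr)
            std::printf("created %s\n", dir);   /// tmpl now holds the actual directory name
        else
            std::perror("mkdtemp");             /// e.g. when /tmp does not exist in the container
    }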
void prepare_for_eof(std::string & filename, std::string & buf)
void prepareForEOF(std::string & filename, std::string & buf)
{
static const std::string symbols = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
@ -28,7 +39,7 @@ void prepare_for_eof(std::string & filename, std::string & buf)
for (size_t i = 0; i < n; ++i)
buf += symbols[i % symbols.length()];
std::ofstream out(filename.c_str());
std::ofstream out(filename);
out << buf;
}
@ -39,7 +50,7 @@ TEST(ReadBufferAIOTest, TestReadAfterAIO)
using namespace DB;
std::string data;
std::string file_path;
prepare_for_eof(file_path, data);
prepareForEOF(file_path, data);
ReadBufferAIO testbuf(file_path);
std::string newdata;

View File

@ -1,3 +1,4 @@
#include "Common/quoteString.h"
#include <Common/typeid_cast.h>
#include <Common/PODArray.h>
#include <Core/Row.h>
@ -334,7 +335,7 @@ void ActionsMatcher::visit(const ASTIdentifier & identifier, const ASTPtr & ast,
found = true;
if (found)
throw Exception("Column " + column_name.get(ast) + " is not under aggregate function and not in GROUP BY.",
throw Exception("Column " + backQuote(column_name.get(ast)) + " is not under aggregate function and not in GROUP BY",
ErrorCodes::NOT_AN_AGGREGATE);
/// Special check for WITH statement alias. Add alias action to be able to use this alias.

View File

@ -2,7 +2,6 @@
#include <Parsers/IAST.h>
#include <Interpreters/PreparedSets.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/SubqueryForSet.h>
#include <Interpreters/InDepthNodeVisitor.h>
@ -13,6 +12,9 @@ namespace DB
class Context;
class ASTFunction;
class ExpressionActions;
using ExpressionActionsPtr = std::shared_ptr<ExpressionActions>;
/// The case of an explicit enumeration of values.
SetPtr makeExplicitSet(
const ASTFunction * node, const Block & sample_block, bool create_ordered_set,

View File

@ -0,0 +1,169 @@
#include <Common/typeid_cast.h>
#include <Columns/ColumnArray.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/Context.h>
#include <Interpreters/ArrayJoinAction.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
extern const int TYPE_MISMATCH;
}
ArrayJoinAction::ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, const Context & context)
: columns(array_joined_columns_)
, is_left(array_join_is_left)
, is_unaligned(context.getSettingsRef().enable_unaligned_array_join)
{
if (columns.empty())
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
if (is_unaligned)
{
function_length = FunctionFactory::instance().get("length", context);
function_greatest = FunctionFactory::instance().get("greatest", context);
function_arrayResize = FunctionFactory::instance().get("arrayResize", context);
}
else if (is_left)
function_builder = FunctionFactory::instance().get("emptyArrayToSingle", context);
}
void ArrayJoinAction::prepare(Block & sample_block)
{
for (const auto & name : columns)
{
ColumnWithTypeAndName & current = sample_block.getByName(name);
const DataTypeArray * array_type = typeid_cast<const DataTypeArray *>(&*current.type);
if (!array_type)
throw Exception("ARRAY JOIN requires array argument", ErrorCodes::TYPE_MISMATCH);
current.type = array_type->getNestedType();
current.column = nullptr;
}
}
void ArrayJoinAction::execute(Block & block, bool dry_run)
{
if (columns.empty())
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
ColumnPtr any_array_ptr = block.getByName(*columns.begin()).column->convertToFullColumnIfConst();
const ColumnArray * any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
if (!any_array)
throw Exception("ARRAY JOIN of not array: " + *columns.begin(), ErrorCodes::TYPE_MISMATCH);
/// If LEFT ARRAY JOIN, then we create columns in which empty arrays are replaced by arrays with one element - the default value.
std::map<String, ColumnPtr> non_empty_array_columns;
if (is_unaligned)
{
/// Resize all array joined columns to the longest one (at least 1 if LEFT ARRAY JOIN), padded with default values.
auto rows = block.rows();
auto uint64 = std::make_shared<DataTypeUInt64>();
ColumnWithTypeAndName column_of_max_length;
if (is_left)
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 1u), uint64, {});
else
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 0u), uint64, {});
for (const auto & name : columns)
{
auto & src_col = block.getByName(name);
Block tmp_block{src_col, {{}, uint64, {}}};
function_length->build({src_col})->execute(tmp_block, {0}, 1, rows);
Block tmp_block2{
column_of_max_length, tmp_block.safeGetByPosition(1), {{}, uint64, {}}};
function_greatest->build({column_of_max_length, tmp_block.safeGetByPosition(1)})->execute(tmp_block2, {0, 1}, 2, rows);
column_of_max_length = tmp_block2.safeGetByPosition(2);
}
for (const auto & name : columns)
{
auto & src_col = block.getByName(name);
Block tmp_block{src_col, column_of_max_length, {{}, src_col.type, {}}};
function_arrayResize->build({src_col, column_of_max_length})->execute(tmp_block, {0, 1}, 2, rows);
src_col.column = tmp_block.safeGetByPosition(2).column;
any_array_ptr = src_col.column->convertToFullColumnIfConst();
}
any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
}
else if (is_left)
{
for (const auto & name : columns)
{
auto src_col = block.getByName(name);
Block tmp_block{src_col, {{}, src_col.type, {}}};
function_builder->build({src_col})->execute(tmp_block, {0}, 1, src_col.column->size(), dry_run);
non_empty_array_columns[name] = tmp_block.safeGetByPosition(1).column;
}
any_array_ptr = non_empty_array_columns.begin()->second->convertToFullColumnIfConst();
any_array = &typeid_cast<const ColumnArray &>(*any_array_ptr);
}
size_t num_columns = block.columns();
for (size_t i = 0; i < num_columns; ++i)
{
ColumnWithTypeAndName & current = block.safeGetByPosition(i);
if (columns.count(current.name))
{
if (!typeid_cast<const DataTypeArray *>(&*current.type))
throw Exception("ARRAY JOIN of not array: " + current.name, ErrorCodes::TYPE_MISMATCH);
ColumnPtr array_ptr = (is_left && !is_unaligned) ? non_empty_array_columns[current.name] : current.column;
array_ptr = array_ptr->convertToFullColumnIfConst();
const ColumnArray & array = typeid_cast<const ColumnArray &>(*array_ptr);
if (!is_unaligned && !array.hasEqualOffsets(typeid_cast<const ColumnArray &>(*any_array_ptr)))
throw Exception("Sizes of ARRAY-JOIN-ed arrays do not match", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
current.column = typeid_cast<const ColumnArray &>(*array_ptr).getDataPtr();
current.type = typeid_cast<const DataTypeArray &>(*current.type).getNestedType();
}
else
{
current.column = current.column->replicate(any_array->getOffsets());
}
}
}
void ArrayJoinAction::finalize(NameSet & needed_columns, NameSet & unmodified_columns, NameSet & final_columns)
{
/// Do not ARRAY JOIN columns that are not used anymore.
/// Usually, such columns are not used until ARRAY JOIN, and therefore are ejected further in this function.
/// We will not remove all the columns so as not to lose the number of rows.
for (auto it = columns.begin(); it != columns.end();)
{
bool need = needed_columns.count(*it);
if (!need && columns.size() > 1)
{
columns.erase(it++);
}
else
{
needed_columns.insert(*it);
unmodified_columns.erase(*it);
/// If no ARRAY JOIN results are used, forcibly leave an arbitrary column at the output,
/// so you do not lose the number of rows.
if (!need)
final_columns.insert(*it);
++it;
}
}
}
}
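A minimal sketch (context, sample_block and block are assumed to exist) of how this class is driven, mirroring the ExpressionActions hunks further below:

    DB::NameSet cols{"arr"};
    DB::ArrayJoinAction action(cols, /* array_join_is_left = */ true, context);
    action.prepare(sample_block);                   /// Array(T) columns become T in the header
    action.execute(block, /* dry_run = */ false);   /// unrolls the arrays; other columns are replicated by offsets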

View File

@ -0,0 +1,35 @@
#pragma once
#include <Core/Names.h>
#include <Core/Block.h>
namespace DB
{
class Context;
class IFunctionOverloadResolver;
using FunctionOverloadResolverPtr = std::shared_ptr<IFunctionOverloadResolver>;
struct ArrayJoinAction
{
NameSet columns;
bool is_left = false;
bool is_unaligned = false;
/// For unaligned [LEFT] ARRAY JOIN
FunctionOverloadResolverPtr function_length;
FunctionOverloadResolverPtr function_greatest;
FunctionOverloadResolverPtr function_arrayResize;
/// For LEFT ARRAY JOIN.
FunctionOverloadResolverPtr function_builder;
ArrayJoinAction(const NameSet & array_joined_columns_, bool array_join_is_left, const Context & context);
void prepare(Block & sample_block);
void execute(Block & block, bool dry_run);
void finalize(NameSet & needed_columns, NameSet & unmodified_columns, NameSet & final_columns);
};
}

View File

@ -57,6 +57,7 @@
#include <Common/TraceCollector.h>
#include <common/logger_useful.h>
#include <Common/RemoteHostFilter.h>
#include <ext/singleton.h>
namespace ProfileEvents
{
@ -168,7 +169,6 @@ struct ContextShared
RemoteHostFilter remote_host_filter; /// Allowed URL from config.xml
std::unique_ptr<TraceCollector> trace_collector; /// Thread collecting traces from threads executing queries
/// Named sessions. The user can specify a session identifier to reuse settings and temporary tables in subsequent requests.
class SessionKeyHash
@ -299,13 +299,7 @@ struct ContextShared
schedule_pool.reset();
ddl_worker.reset();
/// Stop trace collector if any
trace_collector.reset();
}
bool hasTraceCollector()
{
return trace_collector != nullptr;
ext::Singleton<TraceCollector>::reset();
}
void initializeTraceCollector(std::shared_ptr<TraceLog> trace_log)
@ -313,7 +307,7 @@ struct ContextShared
if (trace_log == nullptr)
return;
trace_collector = std::make_unique<TraceCollector>(trace_log);
ext::Singleton<TraceCollector>()->setTraceLog(trace_log);
}
};
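Based only on the two calls visible in this hunk, the assumed surface of the ext::Singleton helper that replaces the member unique_ptr:

    ext::Singleton<TraceCollector>()->setTraceLog(trace_log);   /// constructed on first use, then configured
    ext::Singleton<TraceCollector>::reset();                    /// destroys the instance during shutdown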
@ -650,6 +644,10 @@ void Context::checkAccess(const AccessFlags & access, const std::string_view & d
void Context::checkAccess(const AccessRightsElement & access) const { return checkAccessImpl(access); }
void Context::checkAccess(const AccessRightsElements & access) const { return checkAccessImpl(access); }
void Context::switchRowPolicy()
{
row_policy = getAccessControlManager().getRowPolicyContext(client_info.initial_user);
}
void Context::setUsersConfig(const ConfigurationPtr & config)
{
@ -688,7 +686,7 @@ void Context::calculateAccessRights()
{
auto lock = getLock();
if (user)
std::atomic_store(&access_rights, getAccessControlManager().getAccessRightsContext(client_info, user->access, settings, current_database));
std::atomic_store(&access_rights, getAccessControlManager().getAccessRightsContext(user, client_info, settings, current_database));
}
void Context::setProfile(const String & profile)
@ -701,9 +699,18 @@ void Context::setProfile(const String & profile)
settings_constraints = std::move(new_constraints);
}
std::shared_ptr<const User> Context::getUser(const String & user_name) const
std::shared_ptr<const User> Context::getUser() const
{
return shared->access_control_manager.getUser(user_name);
if (!user)
throw Exception("No current user", ErrorCodes::LOGICAL_ERROR);
return user;
}
UUID Context::getUserID() const
{
if (!user)
throw Exception("No current user", ErrorCodes::LOGICAL_ERROR);
return user_id;
}
void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key)
@ -717,8 +724,9 @@ void Context::setUser(const String & name, const String & password, const Poco::
if (!quota_key.empty())
client_info.quota_key = quota_key;
user_id = shared->access_control_manager.getID<User>(name);
user = shared->access_control_manager.authorizeAndGetUser(
name,
user_id,
password,
address.host(),
[this](const UserPtr & changed_user)
@ -1679,11 +1687,6 @@ void Context::initializeSystemLogs()
shared->system_logs.emplace(*global_context, getConfigRef());
}
bool Context::hasTraceCollector()
{
return shared->hasTraceCollector();
}
void Context::initializeTraceCollector()
{
shared->initializeTraceCollector(getTraceLog());

View File

@ -4,6 +4,7 @@
#include <Core/NamesAndTypes.h>
#include <Core/Settings.h>
#include <Core/Types.h>
#include <Core/UUID.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <Interpreters/ClientInfo.h>
#include <Parsers/IAST_fwd.h>
@ -161,6 +162,7 @@ private:
InputBlocksReader input_blocks_reader;
std::shared_ptr<const User> user;
UUID user_id;
SubscriptionForUserChange subscription_for_user_change;
std::shared_ptr<const AccessRightsContext> access_rights;
std::shared_ptr<QuotaContext> quota; /// Current quota. By default - empty quota, that have no limits.
@ -251,6 +253,10 @@ public:
std::shared_ptr<QuotaContext> getQuota() const { return quota; }
std::shared_ptr<RowPolicyContext> getRowPolicy() const { return row_policy; }
/// TODO: we need much better code for switching policies, quotas, access rights for initial user
/// Switches the row policy in case we have an initial user in the client info
void switchRowPolicy();
/** Take the list of users, quotas and configuration profiles from this config.
* The list of users is completely replaced.
* The accumulated quota values are not reset if the quota is not deleted.
@ -260,10 +266,8 @@ public:
/// Must be called before getClientInfo.
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key);
std::shared_ptr<const User> getUser() const { return user; }
/// Used by MySQL Secure Password Authentication plugin.
std::shared_ptr<const User> getUser(const String & user_name) const;
std::shared_ptr<const User> getUser() const;
UUID getUserID() const;
/// We have to copy external tables inside executeQuery() to track limits. Therefore, set callback for it. Must set once.
void setExternalTablesInitializer(ExternalTablesInitializer && initializer);

View File

@ -6,13 +6,11 @@
#include <Interpreters/ExpressionJIT.h>
#include <Interpreters/AnalyzedJoin.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnArray.h>
#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/IFunction.h>
#include <set>
#include <optional>
#include <Columns/ColumnSet.h>
#include <Functions/FunctionHelpers.h>
@ -33,20 +31,20 @@ namespace ErrorCodes
extern const int UNKNOWN_IDENTIFIER;
extern const int UNKNOWN_ACTION;
extern const int NOT_FOUND_COLUMN_IN_BLOCK;
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
extern const int TOO_MANY_TEMPORARY_COLUMNS;
extern const int TOO_MANY_TEMPORARY_NON_CONST_COLUMNS;
extern const int TYPE_MISMATCH;
}
/// Read comment near usage
static constexpr auto DUMMY_COLUMN_NAME = "_dummy";
Names ExpressionAction::getNeededColumns() const
{
Names res = argument_names;
res.insert(res.end(), array_joined_columns.begin(), array_joined_columns.end());
if (array_join)
res.insert(res.end(), array_join->columns.begin(), array_join->columns.end());
if (table_join)
res.insert(res.end(), table_join->keyNamesLeft().begin(), table_join->keyNamesLeft().end());
@ -143,23 +141,9 @@ ExpressionAction ExpressionAction::addAliases(const NamesWithAliases & aliased_c
ExpressionAction ExpressionAction::arrayJoin(const NameSet & array_joined_columns, bool array_join_is_left, const Context & context)
{
if (array_joined_columns.empty())
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
ExpressionAction a;
a.type = ARRAY_JOIN;
a.array_joined_columns = array_joined_columns;
a.array_join_is_left = array_join_is_left;
a.unaligned_array_join = context.getSettingsRef().enable_unaligned_array_join;
if (a.unaligned_array_join)
{
a.function_length = FunctionFactory::instance().get("length", context);
a.function_greatest = FunctionFactory::instance().get("greatest", context);
a.function_arrayResize = FunctionFactory::instance().get("arrayResize", context);
}
else if (array_join_is_left)
a.function_builder = FunctionFactory::instance().get("emptyArrayToSingle", context);
a.array_join = std::make_shared<ArrayJoinAction>(array_joined_columns, array_join_is_left, context);
return a;
}
@ -172,7 +156,6 @@ ExpressionAction ExpressionAction::ordinaryJoin(std::shared_ptr<AnalyzedJoin> ta
return a;
}
void ExpressionAction::prepare(Block & sample_block, const Settings & settings, NameSet & names_not_for_constant_folding)
{
// std::cerr << "preparing: " << toString() << std::endl;
@ -256,16 +239,7 @@ void ExpressionAction::prepare(Block & sample_block, const Settings & settings,
case ARRAY_JOIN:
{
for (const auto & name : array_joined_columns)
{
ColumnWithTypeAndName & current = sample_block.getByName(name);
const DataTypeArray * array_type = typeid_cast<const DataTypeArray *>(&*current.type);
if (!array_type)
throw Exception("ARRAY JOIN requires array argument", ErrorCodes::TYPE_MISMATCH);
current.type = array_type->getNestedType();
current.column = nullptr;
}
array_join->prepare(sample_block);
break;
}
@ -383,95 +357,7 @@ void ExpressionAction::execute(Block & block, bool dry_run, ExtraBlockPtr & not_
case ARRAY_JOIN:
{
if (array_joined_columns.empty())
throw Exception("No arrays to join", ErrorCodes::LOGICAL_ERROR);
ColumnPtr any_array_ptr = block.getByName(*array_joined_columns.begin()).column->convertToFullColumnIfConst();
const ColumnArray * any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
if (!any_array)
throw Exception("ARRAY JOIN of not array: " + *array_joined_columns.begin(), ErrorCodes::TYPE_MISMATCH);
/// If LEFT ARRAY JOIN, then we create columns in which empty arrays are replaced by arrays with one element - the default value.
std::map<String, ColumnPtr> non_empty_array_columns;
if (unaligned_array_join)
{
/// Resize all array joined columns to the longest one (at least 1 if LEFT ARRAY JOIN), padded with default values.
auto rows = block.rows();
auto uint64 = std::make_shared<DataTypeUInt64>();
ColumnWithTypeAndName column_of_max_length;
if (array_join_is_left)
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 1u), uint64, {});
else
column_of_max_length = ColumnWithTypeAndName(uint64->createColumnConst(rows, 0u), uint64, {});
for (const auto & name : array_joined_columns)
{
auto & src_col = block.getByName(name);
Block tmp_block{src_col, {{}, uint64, {}}};
function_length->build({src_col})->execute(tmp_block, {0}, 1, rows);
Block tmp_block2{
column_of_max_length, tmp_block.safeGetByPosition(1), {{}, uint64, {}}};
function_greatest->build({column_of_max_length, tmp_block.safeGetByPosition(1)})->execute(tmp_block2, {0, 1}, 2, rows);
column_of_max_length = tmp_block2.safeGetByPosition(2);
}
for (const auto & name : array_joined_columns)
{
auto & src_col = block.getByName(name);
Block tmp_block{src_col, column_of_max_length, {{}, src_col.type, {}}};
function_arrayResize->build({src_col, column_of_max_length})->execute(tmp_block, {0, 1}, 2, rows);
src_col.column = tmp_block.safeGetByPosition(2).column;
any_array_ptr = src_col.column->convertToFullColumnIfConst();
}
any_array = typeid_cast<const ColumnArray *>(&*any_array_ptr);
}
else if (array_join_is_left)
{
for (const auto & name : array_joined_columns)
{
auto src_col = block.getByName(name);
Block tmp_block{src_col, {{}, src_col.type, {}}};
function_builder->build({src_col})->execute(tmp_block, {0}, 1, src_col.column->size(), dry_run);
non_empty_array_columns[name] = tmp_block.safeGetByPosition(1).column;
}
any_array_ptr = non_empty_array_columns.begin()->second->convertToFullColumnIfConst();
any_array = &typeid_cast<const ColumnArray &>(*any_array_ptr);
}
size_t columns = block.columns();
for (size_t i = 0; i < columns; ++i)
{
ColumnWithTypeAndName & current = block.safeGetByPosition(i);
if (array_joined_columns.count(current.name))
{
if (!typeid_cast<const DataTypeArray *>(&*current.type))
throw Exception("ARRAY JOIN of not array: " + current.name, ErrorCodes::TYPE_MISMATCH);
ColumnPtr array_ptr = (array_join_is_left && !unaligned_array_join) ? non_empty_array_columns[current.name] : current.column;
array_ptr = array_ptr->convertToFullColumnIfConst();
const ColumnArray & array = typeid_cast<const ColumnArray &>(*array_ptr);
if (!unaligned_array_join && !array.hasEqualOffsets(typeid_cast<const ColumnArray &>(*any_array_ptr)))
throw Exception("Sizes of ARRAY-JOIN-ed arrays do not match", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
current.column = typeid_cast<const ColumnArray &>(*array_ptr).getDataPtr();
current.type = typeid_cast<const DataTypeArray &>(*current.type).getNestedType();
}
else
{
current.column = current.column->replicate(any_array->getOffsets());
}
}
array_join->execute(block, dry_run);
break;
}
@ -539,7 +425,6 @@ void ExpressionAction::execute(Block & block, bool dry_run, ExtraBlockPtr & not_
}
}
void ExpressionAction::executeOnTotals(Block & block) const
{
if (type != JOIN)
@ -584,10 +469,10 @@ std::string ExpressionAction::toString() const
break;
case ARRAY_JOIN:
ss << (array_join_is_left ? "LEFT " : "") << "ARRAY JOIN ";
for (NameSet::const_iterator it = array_joined_columns.begin(); it != array_joined_columns.end(); ++it)
ss << (array_join->is_left ? "LEFT " : "") << "ARRAY JOIN ";
for (NameSet::const_iterator it = array_join->columns.begin(); it != array_join->columns.end(); ++it)
{
if (it != array_joined_columns.begin())
if (it != array_join->columns.begin())
ss << ", ";
ss << *it;
}
@ -675,7 +560,9 @@ void ExpressionActions::addImpl(ExpressionAction action, Names & new_names)
{
if (action.result_name != "")
new_names.push_back(action.result_name);
new_names.insert(new_names.end(), action.array_joined_columns.begin(), action.array_joined_columns.end());
if (action.array_join)
new_names.insert(new_names.end(), action.array_join->columns.begin(), action.array_join->columns.end());
/// Compiled functions are custom functions and they don't need building
if (action.type == ExpressionAction::APPLY_FUNCTION && !action.is_function_compiled)
@ -713,7 +600,7 @@ void ExpressionActions::prependArrayJoin(const ExpressionAction & action, const
if (action.type != ExpressionAction::ARRAY_JOIN)
throw Exception("ARRAY_JOIN action expected", ErrorCodes::LOGICAL_ERROR);
NameSet array_join_set(action.array_joined_columns.begin(), action.array_joined_columns.end());
NameSet array_join_set(action.array_join->columns.begin(), action.array_join->columns.end());
for (auto & it : input_columns)
{
if (array_join_set.count(it.name))
@ -738,12 +625,12 @@ bool ExpressionActions::popUnusedArrayJoin(const Names & required_columns, Expre
if (actions.empty() || actions.back().type != ExpressionAction::ARRAY_JOIN)
return false;
NameSet required_set(required_columns.begin(), required_columns.end());
for (const std::string & name : actions.back().array_joined_columns)
for (const std::string & name : actions.back().array_join->columns)
{
if (required_set.count(name))
return false;
}
for (const std::string & name : actions.back().array_joined_columns)
for (const std::string & name : actions.back().array_join->columns)
{
DataTypePtr & type = sample_block.getByName(name).type;
type = std::make_shared<DataTypeArray>(type);
@ -884,29 +771,7 @@ void ExpressionActions::finalize(const Names & output_columns)
}
else if (action.type == ExpressionAction::ARRAY_JOIN)
{
/// Do not ARRAY JOIN columns that are not used anymore.
/// Usually, such columns are not used until ARRAY JOIN, and therefore are ejected further in this function.
/// We will not remove all the columns so as not to lose the number of rows.
for (auto it = action.array_joined_columns.begin(); it != action.array_joined_columns.end();)
{
bool need = needed_columns.count(*it);
if (!need && action.array_joined_columns.size() > 1)
{
action.array_joined_columns.erase(it++);
}
else
{
needed_columns.insert(*it);
unmodified_columns.erase(*it);
/// If no ARRAY JOIN results are used, forcibly leave an arbitrary column at the output,
/// so you do not lose the number of rows.
if (!need)
final_columns.insert(*it);
++it;
}
}
action.array_join->finalize(needed_columns, unmodified_columns, final_columns);
}
else
{
@ -1143,7 +1008,8 @@ void ExpressionActions::optimizeArrayJoin()
if (actions[i].result_name != "")
array_joined_columns.insert(actions[i].result_name);
array_joined_columns.insert(actions[i].array_joined_columns.begin(), actions[i].array_joined_columns.end());
if (actions[i].array_join)
array_joined_columns.insert(actions[i].array_join->columns.begin(), actions[i].array_join->columns.end());
array_join_dependencies.insert(needed.begin(), needed.end());
}
@ -1274,8 +1140,8 @@ UInt128 ExpressionAction::ActionHash::operator()(const ExpressionAction & action
hash.update(arg_name);
break;
case ARRAY_JOIN:
hash.update(action.array_join_is_left);
for (const auto & col : action.array_joined_columns)
hash.update(action.array_join->is_left);
for (const auto & col : action.array_join->columns)
hash.update(col);
break;
case JOIN:
@ -1332,11 +1198,15 @@ bool ExpressionAction::operator==(const ExpressionAction & other) const
return false;
}
bool same_array_join = !array_join && !other.array_join;
if (array_join && other.array_join)
same_array_join = (array_join->columns == other.array_join->columns) &&
(array_join->is_left == other.array_join->is_left);
return source_name == other.source_name
&& result_name == other.result_name
&& argument_names == other.argument_names
&& array_joined_columns == other.array_joined_columns
&& array_join_is_left == other.array_join_is_left
&& same_array_join
&& AnalyzedJoin::sameJoin(table_join.get(), other.table_join.get())
&& projection == other.projection
&& is_function_compiled == other.is_function_compiled;

View File

@ -11,6 +11,7 @@
#include <unordered_map>
#include <unordered_set>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Interpreters/ArrayJoinAction.h>
namespace DB
@ -81,15 +82,10 @@ public:
/// For ADD_COLUMN.
ColumnPtr added_column;
/// For APPLY_FUNCTION and LEFT ARRAY JOIN.
/// For APPLY_FUNCTION.
/// OverloadResolver is used before action was added to ExpressionActions (when we don't know types of arguments).
FunctionOverloadResolverPtr function_builder;
/// For unaligned [LEFT] ARRAY JOIN
FunctionOverloadResolverPtr function_length;
FunctionOverloadResolverPtr function_greatest;
FunctionOverloadResolverPtr function_arrayResize;
/// Can be used after action was added to ExpressionActions if we want to get function signature or properties like monotonicity.
FunctionBasePtr function_base;
/// Prepared function which is used in function execution.
@ -97,10 +93,8 @@ public:
Names argument_names;
bool is_function_compiled = false;
/// For ARRAY_JOIN
NameSet array_joined_columns;
bool array_join_is_left = false;
bool unaligned_array_join = false;
/// For ARRAY JOIN
std::shared_ptr<ArrayJoinAction> array_join;
/// For JOIN
std::shared_ptr<const AnalyzedJoin> table_join;

View File

@ -70,9 +70,51 @@ using LogAST = DebugASTLog<false>; /// set to true to enable logs
namespace ErrorCodes
{
extern const int UNKNOWN_IDENTIFIER;
extern const int ILLEGAL_PREWHERE;
extern const int LOGICAL_ERROR;
}
namespace
{
/// Check if there is an ignore function. It's used for disabling constant folding in query
/// predicates because some performance tests use the ignore function as a non-optimization guard.
bool allowEarlyConstantFolding(const ExpressionActions & actions, const Settings & settings)
{
if (!settings.enable_early_constant_folding)
return false;
for (auto & action : actions.getActions())
{
if (action.type == action.APPLY_FUNCTION && action.function_base)
{
auto name = action.function_base->getName();
if (name == "ignore")
return false;
}
}
return true;
}
}
bool sanitizeBlock(Block & block)
{
for (auto & col : block)
{
if (!col.column)
{
if (isNotCreatable(col.type->getTypeId()))
return false;
col.column = col.type->createColumn();
}
else if (isColumnConst(*col.column) && !col.column->empty())
col.column = col.column->cloneEmpty();
}
return true;
}
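A small illustration (types chosen arbitrarily) of the guarantee sanitizeBlock provides: afterwards every column object exists, and non-empty constants are cut down to empty columns, so the block can serve as a pure header:

    DB::Block header{{nullptr, std::make_shared<DB::DataTypeUInt64>(), "x"}};
    bool ok = DB::sanitizeBlock(header);   /// creates an empty ColumnUInt64 for "x"; false only for non-creatable types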
ExpressionAnalyzer::ExpressionAnalyzer(
const ASTPtr & query_,
const SyntaxAnalyzerResultPtr & syntax_analyzer_result_,
@ -733,7 +775,8 @@ void SelectQueryExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain,
step.required_output.push_back(child->getColumnName());
}
bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order)
bool SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order,
ManyExpressionActions & order_by_elements_actions)
{
const auto * select_query = getSelectQuery();
@ -884,12 +927,239 @@ ExpressionActionsPtr ExpressionAnalyzer::getConstActions()
return actions;
}
void SelectQueryExpressionAnalyzer::getAggregateInfo(Names & key_names, AggregateDescriptions & aggregates) const
ExpressionActionsPtr SelectQueryExpressionAnalyzer::simpleSelectActions()
{
for (const auto & name_and_type : aggregation_keys)
key_names.emplace_back(name_and_type.name);
ExpressionActionsChain new_chain(context);
appendSelect(new_chain, false);
return new_chain.getLastActions();
}
aggregates = aggregate_descriptions;
ExpressionAnalysisResult::ExpressionAnalysisResult(
SelectQueryExpressionAnalyzer & query_analyzer,
bool first_stage_,
bool second_stage_,
bool only_types,
const FilterInfoPtr & filter_info_,
const Block & source_header)
: first_stage(first_stage_)
, second_stage(second_stage_)
, need_aggregate(query_analyzer.hasAggregation())
{
/// first_stage: Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
/// second_stage: Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
/** First we compose a chain of actions and remember the necessary steps from it.
* Regardless of from_stage and to_stage, we will compose a complete sequence of actions to perform optimization and
* throw out unnecessary columns based on the entire query. In unnecessary parts of the query, we will not execute subqueries.
*/
const ASTSelectQuery & query = *query_analyzer.getSelectQuery();
const Context & context = query_analyzer.context;
const Settings & settings = context.getSettingsRef();
const StoragePtr & storage = query_analyzer.storage();
bool finalized = false;
size_t where_step_num = 0;
auto finalizeChain = [&](ExpressionActionsChain & chain)
{
if (!finalized)
{
chain.finalize();
finalize(chain, context, where_step_num);
chain.clear();
}
finalized = true;
};
{
ExpressionActionsChain chain(context);
Names additional_required_columns_after_prewhere;
if (storage && (query.sample_size() || settings.parallel_replicas_count > 1))
{
Names columns_for_sampling = storage->getColumnsRequiredForSampling();
additional_required_columns_after_prewhere.insert(additional_required_columns_after_prewhere.end(),
columns_for_sampling.begin(), columns_for_sampling.end());
}
if (storage && query.final())
{
Names columns_for_final = storage->getColumnsRequiredForFinal();
additional_required_columns_after_prewhere.insert(additional_required_columns_after_prewhere.end(),
columns_for_final.begin(), columns_for_final.end());
}
if (storage && filter_info_)
{
filter_info = filter_info_;
query_analyzer.appendPreliminaryFilter(chain, filter_info->actions, filter_info->column_name);
}
if (query_analyzer.appendPrewhere(chain, !first_stage, additional_required_columns_after_prewhere))
{
prewhere_info = std::make_shared<PrewhereInfo>(
chain.steps.front().actions, query.prewhere()->getColumnName());
if (allowEarlyConstantFolding(*prewhere_info->prewhere_actions, settings))
{
Block before_prewhere_sample = source_header;
if (sanitizeBlock(before_prewhere_sample))
{
prewhere_info->prewhere_actions->execute(before_prewhere_sample);
auto & column_elem = before_prewhere_sample.getByName(query.prewhere()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
prewhere_constant_filter_description = ConstantFilterDescription(*column_elem.column);
}
}
chain.addStep();
}
query_analyzer.appendArrayJoin(chain, only_types || !first_stage);
if (query_analyzer.appendJoin(chain, only_types || !first_stage))
{
before_join = chain.getLastActions();
if (!hasJoin())
throw Exception("No expected JOIN", ErrorCodes::LOGICAL_ERROR);
chain.addStep();
}
if (query_analyzer.appendWhere(chain, only_types || !first_stage))
{
where_step_num = chain.steps.size() - 1;
before_where = chain.getLastActions();
if (allowEarlyConstantFolding(*before_where, settings))
{
Block before_where_sample;
if (chain.steps.size() > 1)
before_where_sample = chain.steps[chain.steps.size() - 2].actions->getSampleBlock();
else
before_where_sample = source_header;
if (sanitizeBlock(before_where_sample))
{
before_where->execute(before_where_sample);
auto & column_elem = before_where_sample.getByName(query.where()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
where_constant_filter_description = ConstantFilterDescription(*column_elem.column);
}
}
chain.addStep();
}
if (need_aggregate)
{
query_analyzer.appendGroupBy(chain, only_types || !first_stage);
query_analyzer.appendAggregateFunctionsArguments(chain, only_types || !first_stage);
before_aggregation = chain.getLastActions();
finalizeChain(chain);
if (query_analyzer.appendHaving(chain, only_types || !second_stage))
{
before_having = chain.getLastActions();
chain.addStep();
}
}
bool has_stream_with_non_joined_rows = (before_join && before_join->getTableJoinAlgo()->hasStreamWithNonJoinedRows());
optimize_read_in_order =
settings.optimize_read_in_order
&& storage && query.orderBy()
&& !query_analyzer.hasAggregation()
&& !query.final()
&& !has_stream_with_non_joined_rows;
/// If there is aggregation, we execute expressions in SELECT and ORDER BY on the initiating server, otherwise on the source servers.
query_analyzer.appendSelect(chain, only_types || (need_aggregate ? !second_stage : !first_stage));
selected_columns = chain.getLastStep().required_output;
has_order_by = query_analyzer.appendOrderBy(chain, only_types || (need_aggregate ? !second_stage : !first_stage),
optimize_read_in_order, order_by_elements_actions);
before_order_and_select = chain.getLastActions();
chain.addStep();
if (query_analyzer.appendLimitBy(chain, only_types || !second_stage))
{
before_limit_by = chain.getLastActions();
chain.addStep();
}
query_analyzer.appendProjectResult(chain);
final_projection = chain.getLastActions();
finalizeChain(chain);
}
/// Before executing WHERE and HAVING, remove the extra columns from the block (mostly the aggregation keys).
removeExtraColumns();
checkActions();
}
void ExpressionAnalysisResult::finalize(const ExpressionActionsChain & chain, const Context & context_, size_t where_step_num)
{
if (hasPrewhere())
{
const ExpressionActionsChain::Step & step = chain.steps.at(0);
prewhere_info->remove_prewhere_column = step.can_remove_required_output.at(0);
Names columns_to_remove;
for (size_t i = 1; i < step.required_output.size(); ++i)
{
if (step.can_remove_required_output[i])
columns_to_remove.push_back(step.required_output[i]);
}
if (!columns_to_remove.empty())
{
auto columns = prewhere_info->prewhere_actions->getSampleBlock().getNamesAndTypesList();
ExpressionActionsPtr actions = std::make_shared<ExpressionActions>(columns, context_);
for (const auto & column : columns_to_remove)
actions->add(ExpressionAction::removeColumn(column));
prewhere_info->remove_columns_actions = std::move(actions);
}
columns_to_remove_after_prewhere = std::move(columns_to_remove);
}
else if (hasFilter())
{
/// Can't have prewhere and filter set simultaneously
filter_info->do_remove_column = chain.steps.at(0).can_remove_required_output.at(0);
}
if (hasWhere())
remove_where_filter = chain.steps.at(where_step_num).can_remove_required_output.at(0);
}
void ExpressionAnalysisResult::removeExtraColumns()
{
if (hasFilter())
filter_info->actions->prependProjectInput();
if (hasWhere())
before_where->prependProjectInput();
if (hasHaving())
before_having->prependProjectInput();
}
void ExpressionAnalysisResult::checkActions()
{
/// Check that PREWHERE doesn't contain unusual actions. Unusual actions are those that can change the number of rows.
if (hasPrewhere())
{
auto check_actions = [](const ExpressionActionsPtr & actions)
{
if (actions)
for (const auto & action : actions->getActions())
if (action.type == ExpressionAction::Type::JOIN || action.type == ExpressionAction::Type::ARRAY_JOIN)
throw Exception("PREWHERE cannot contain ARRAY JOIN or JOIN action", ErrorCodes::ILLEGAL_PREWHERE);
};
check_actions(prewhere_info->prewhere_actions);
check_actions(prewhere_info->alias_actions);
check_actions(prewhere_info->remove_columns_actions);
}
}
}

View File

@ -2,11 +2,13 @@
#include <Core/Settings.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <Columns/FilterDescription.h>
#include <Interpreters/AggregateDescription.h>
#include <Interpreters/SyntaxAnalyzer.h>
#include <Interpreters/SubqueryForSet.h>
#include <Parsers/IAST_fwd.h>
#include <Storages/IStorage_fwd.h>
#include <Storages/SelectQueryInfo.h>
namespace DB
@ -29,6 +31,9 @@ class ASTExpressionList;
class ASTSelectQuery;
struct ASTTablesInSelectQueryElement;
/// Create columns in the block, or return false if that's not possible
bool sanitizeBlock(Block & block);
/// ExpressionAnalyzer sources, intermediates and results. It splits data and logic, allows to test them separately.
struct ExpressionAnalyzerData
{
@ -47,9 +52,6 @@ struct ExpressionAnalyzerData
/// All new temporary tables obtained by performing the GLOBAL IN/JOIN subqueries.
Tables external_tables;
/// Actions by every element of ORDER BY
ManyExpressionActions order_by_elements_actions;
};
@ -156,10 +158,71 @@ protected:
bool isRemoteStorage() const;
};
class SelectQueryExpressionAnalyzer;
/// Result of SelectQueryExpressionAnalyzer: expressions for InterpreterSelectQuery
struct ExpressionAnalysisResult
{
/// Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
bool first_stage = false;
/// Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
bool second_stage = false;
bool need_aggregate = false;
bool has_order_by = false;
bool remove_where_filter = false;
bool optimize_read_in_order = false;
ExpressionActionsPtr before_join; /// including JOIN
ExpressionActionsPtr before_where;
ExpressionActionsPtr before_aggregation;
ExpressionActionsPtr before_having;
ExpressionActionsPtr before_order_and_select;
ExpressionActionsPtr before_limit_by;
ExpressionActionsPtr final_projection;
/// Columns from the SELECT list, before renaming them to aliases.
Names selected_columns;
/// Columns that will be removed after the PREWHERE actions are executed.
Names columns_to_remove_after_prewhere;
PrewhereInfoPtr prewhere_info;
FilterInfoPtr filter_info;
ConstantFilterDescription prewhere_constant_filter_description;
ConstantFilterDescription where_constant_filter_description;
/// Actions by every element of ORDER BY
ManyExpressionActions order_by_elements_actions;
ExpressionAnalysisResult() = default;
ExpressionAnalysisResult(
SelectQueryExpressionAnalyzer & query_analyzer,
bool first_stage,
bool second_stage,
bool only_types,
const FilterInfoPtr & filter_info,
const Block & source_header);
bool hasFilter() const { return filter_info.get(); }
bool hasJoin() const { return before_join.get(); }
bool hasPrewhere() const { return prewhere_info.get(); }
bool hasWhere() const { return before_where.get(); }
bool hasHaving() const { return before_having.get(); }
bool hasLimitBy() const { return before_limit_by.get(); }
void removeExtraColumns();
void checkActions();
void finalize(const ExpressionActionsChain & chain, const Context & context, size_t where_step_num);
};
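For orientation, the call site in InterpreterSelectQuery (shown in a later hunk) constructs this struct roughly as follows:

    analysis_result = ExpressionAnalysisResult(
        *query_analyzer,
        first_stage, second_stage,
        options.only_analyze,
        filter_info,
        source_header);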
/// SelectQuery specific ExpressionAnalyzer part.
class SelectQueryExpressionAnalyzer : public ExpressionAnalyzer
{
public:
friend struct ExpressionAnalysisResult;
SelectQueryExpressionAnalyzer(
const ASTPtr & query_,
const SyntaxAnalyzerResultPtr & syntax_analyzer_result_,
@ -175,16 +238,46 @@ public:
bool hasAggregation() const { return has_aggregation; }
bool hasGlobalSubqueries() { return has_global_subqueries; }
/// Get a list of aggregation keys and descriptions of aggregate functions if the query contains GROUP BY.
void getAggregateInfo(Names & key_names, AggregateDescriptions & aggregates) const;
const NamesAndTypesList & aggregationKeys() const { return aggregation_keys; }
const AggregateDescriptions & aggregates() const { return aggregate_descriptions; }
/// Create Set-s that we make from the IN section to use the index on them.
void makeSetsForIndex(const ASTPtr & node);
const PreparedSets & getPreparedSets() const { return prepared_sets; }
const ManyExpressionActions & getOrderByActions() const { return order_by_elements_actions; }
/// Tables that will need to be sent to remote servers for distributed query processing.
const Tables & getExternalTables() const { return external_tables; }
ExpressionActionsPtr simpleSelectActions();
/// These appends are public only for tests
void appendSelect(ExpressionActionsChain & chain, bool only_types);
/// Deletes all columns except mentioned by SELECT, arranges the remaining columns and renames them to aliases.
void appendProjectResult(ExpressionActionsChain & chain) const;
private:
/// If non-empty, ignore all expressions not from this list.
NameSet required_result_columns;
/**
* Create Set from a subquery or a table expression in the query. The created set is suitable for using the index.
* The set will not be created if its size hits the limit.
*/
void tryMakeSetForIndexFromSubquery(const ASTPtr & subquery_or_table_name);
/**
* Checks that the subquery is not a plain StorageSet,
* because while making the set we would read data from that StorageSet, which is not allowed.
* Returns valid SetPtr from StorageSet if the latter is used after IN or nullptr otherwise.
*/
SetPtr isPlainStorageSetInSubquery(const ASTPtr & subquery_of_table_name);
JoinPtr makeTableJoin(const ASTTablesInSelectQueryElement & join_element);
void makeSubqueryForJoin(const ASTTablesInSelectQueryElement & join_element, NamesWithAliases && required_columns_with_aliases,
SubqueryForSet & subquery_for_set) const;
const ASTSelectQuery * getAggregatingQuery() const;
/** These methods allow you to build a chain of transformations over a block, that receives values in the desired sections of the query.
*
* Example usage:
@ -213,37 +306,10 @@ public:
/// After aggregation:
bool appendHaving(ExpressionActionsChain & chain, bool only_types);
void appendSelect(ExpressionActionsChain & chain, bool only_types);
bool appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order);
/// appendSelect
bool appendOrderBy(ExpressionActionsChain & chain, bool only_types, bool optimize_read_in_order, ManyExpressionActions &);
bool appendLimitBy(ExpressionActionsChain & chain, bool only_types);
/// Deletes all columns except mentioned by SELECT, arranges the remaining columns and renames them to aliases.
void appendProjectResult(ExpressionActionsChain & chain) const;
/// Create Set-s that we can make from the IN section to use the index on them.
void makeSetsForIndex(const ASTPtr & node);
private:
/// If non-empty, ignore all expressions not from this list.
NameSet required_result_columns;
/**
* Create Set from a subquery or a table expression in the query. The created set is suitable for using the index.
* The set will not be created if its size hits the limit.
*/
void tryMakeSetForIndexFromSubquery(const ASTPtr & subquery_or_table_name);
/**
* Checks that the subquery is not a plain StorageSet,
* because while making the set we would read data from that StorageSet, which is not allowed.
* Returns valid SetPtr from StorageSet if the latter is used after IN or nullptr otherwise.
*/
SetPtr isPlainStorageSetInSubquery(const ASTPtr & subquery_of_table_name);
JoinPtr makeTableJoin(const ASTTablesInSelectQueryElement & join_element);
void makeSubqueryForJoin(const ASTTablesInSelectQueryElement & join_element, NamesWithAliases && required_columns_with_aliases,
SubqueryForSet & subquery_for_set) const;
const ASTSelectQuery * getAggregatingQuery() const;
/// appendProjectResult
};
}

View File

@ -14,7 +14,7 @@ ExternalModelsLoader::ExternalModelsLoader(Context & context_)
: ExternalLoader("external model", &Logger::get("ExternalModelsLoader"))
, context(context_)
{
setConfigSettings({"models", "name", {}});
setConfigSettings({"model", "name", {}});
enablePeriodicUpdates(true);
}

View File

@ -322,7 +322,7 @@ ColumnsDescription InterpreterCreateQuery::getColumnsDescription(const ASTExpres
{
auto syntax_analyzer_result = SyntaxAnalyzer(context).analyze(default_expr_list, column_names_and_types);
const auto actions = ExpressionAnalyzer(default_expr_list, syntax_analyzer_result, context).getActions(true);
for (auto action : actions->getActions())
for (auto & action : actions->getActions())
if (action.type == ExpressionAction::Type::JOIN || action.type == ExpressionAction::Type::ARRAY_JOIN)
throw Exception("Cannot CREATE table. Unsupported default value that requires ARRAY JOIN or JOIN action", ErrorCodes::THERE_IS_NO_DEFAULT_VALUE);
@ -545,10 +545,12 @@ BlockIO InterpreterCreateQuery::createTable(ASTCreateQuery & create)
// If this is a stub ATTACH query, read the query definition from the database
if (create.attach && !create.storage && !create.columns_list)
{
bool if_not_exists = create.if_not_exists;
// Table SQL definition is available even if the table is detached
auto query = context.getDatabase(create.database)->getCreateTableQuery(context, create.table);
create = query->as<ASTCreateQuery &>(); // Copy the saved create query, but use ATTACH instead of CREATE
create.attach = true;
create.if_not_exists = if_not_exists;
}
String current_database = context.getCurrentDatabase();

View File

@ -6,6 +6,7 @@
#include <Storages/IndicesDescription.h>
#include <Storages/ConstraintsDescription.h>
#include <Common/ThreadPool.h>
#include <Access/AccessRightsElement.h>
namespace DB

View File

@ -0,0 +1,72 @@
#include <Interpreters/InterpreterCreateUserQuery.h>
#include <Parsers/ASTCreateUserQuery.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/User.h>
namespace DB
{
BlockIO InterpreterCreateUserQuery::execute()
{
const auto & query = query_ptr->as<const ASTCreateUserQuery &>();
auto & access_control = context.getAccessControlManager();
context.checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER);
if (query.alter)
{
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone());
updateUserFromQuery(*updated_user, query);
return updated_user;
};
if (query.if_exists)
{
if (auto id = access_control.find<User>(query.name))
access_control.tryUpdate(*id, update_func);
}
else
access_control.update(access_control.getID<User>(query.name), update_func);
}
else
{
auto new_user = std::make_shared<User>();
updateUserFromQuery(*new_user, query);
if (query.if_not_exists)
access_control.tryInsert(new_user);
else if (query.or_replace)
access_control.insertOrReplace(new_user);
else
access_control.insert(new_user);
}
return {};
}
void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreateUserQuery & query)
{
if (query.alter)
{
if (!query.new_name.empty())
user.setName(query.new_name);
}
else
user.setName(query.name);
if (query.authentication)
user.authentication = *query.authentication;
if (query.hosts)
user.allowed_client_hosts = *query.hosts;
if (query.remove_hosts)
user.allowed_client_hosts.remove(*query.remove_hosts);
if (query.add_hosts)
user.allowed_client_hosts.add(*query.add_hosts);
if (query.profile)
user.profile = *query.profile;
}
}
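The clone-then-mutate pattern used by this and the other access-entity interpreters, shown in isolation (the setName call is a hypothetical mutation):

    auto update_func = [&](const DB::AccessEntityPtr & entity) -> DB::AccessEntityPtr
    {
        auto updated = typeid_cast<std::shared_ptr<DB::User>>(entity->clone());
        updated->setName("renamed");   /// hypothetical change
        return updated;
    };
    access_control.update(access_control.getID<DB::User>("name"), update_func);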

View File

@ -0,0 +1,26 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class ASTCreateUserQuery;
struct User;
class InterpreterCreateUserQuery : public IInterpreter
{
public:
InterpreterCreateUserQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
private:
void updateUserFromQuery(User & user, const ASTCreateUserQuery & query);
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -5,6 +5,7 @@
#include <Access/AccessFlags.h>
#include <Access/Quota.h>
#include <Access/RowPolicy.h>
#include <Access/User.h>
#include <boost/range/algorithm/transform.hpp>
@ -18,6 +19,16 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
switch (query.kind)
{
case Kind::USER:
{
context.checkAccess(AccessType::DROP_USER);
if (query.if_exists)
access_control.tryRemove(access_control.find<User>(query.names));
else
access_control.remove(access_control.getIDs<User>(query.names));
return {};
}
case Kind::QUOTA:
{
context.checkAccess(AccessType::DROP_QUOTA);
@ -27,6 +38,7 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
access_control.remove(access_control.getIDs<Quota>(query.names));
return {};
}
case Kind::ROW_POLICY:
{
context.checkAccess(AccessType::DROP_POLICY);

View File

@ -1,6 +1,7 @@
#include <Parsers/ASTAlterQuery.h>
#include <Parsers/ASTCheckQuery.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTCreateUserQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTDropAccessEntityQuery.h>
@ -14,6 +15,7 @@
#include <Parsers/ASTSetQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTShowProcesslistQuery.h>
#include <Parsers/ASTShowGrantsQuery.h>
#include <Parsers/ASTShowQuotasQuery.h>
#include <Parsers/ASTShowRowPoliciesQuery.h>
#include <Parsers/ASTShowTablesQuery.h>
@ -21,10 +23,12 @@
#include <Parsers/ASTExplainQuery.h>
#include <Parsers/TablePropertiesQueriesASTs.h>
#include <Parsers/ASTWatchQuery.h>
#include <Parsers/ASTGrantQuery.h>
#include <Interpreters/InterpreterAlterQuery.h>
#include <Interpreters/InterpreterCheckQuery.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <Interpreters/InterpreterCreateUserQuery.h>
#include <Interpreters/InterpreterCreateQuotaQuery.h>
#include <Interpreters/InterpreterCreateRowPolicyQuery.h>
#include <Interpreters/InterpreterDescribeQuery.h>
@ -43,12 +47,14 @@
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/InterpreterShowCreateQuery.h>
#include <Interpreters/InterpreterShowProcesslistQuery.h>
#include <Interpreters/InterpreterShowGrantsQuery.h>
#include <Interpreters/InterpreterShowQuotasQuery.h>
#include <Interpreters/InterpreterShowRowPoliciesQuery.h>
#include <Interpreters/InterpreterShowTablesQuery.h>
#include <Interpreters/InterpreterSystemQuery.h>
#include <Interpreters/InterpreterUseQuery.h>
#include <Interpreters/InterpreterWatchQuery.h>
#include <Interpreters/InterpreterGrantQuery.h>
#include <Parsers/ASTSystemQuery.h>
@ -176,6 +182,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{
return std::make_unique<InterpreterWatchQuery>(query, context);
}
else if (query->as<ASTCreateUserQuery>())
{
return std::make_unique<InterpreterCreateUserQuery>(query, context);
}
else if (query->as<ASTCreateQuotaQuery>())
{
return std::make_unique<InterpreterCreateQuotaQuery>(query, context);
@ -188,10 +198,18 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{
return std::make_unique<InterpreterDropAccessEntityQuery>(query, context);
}
else if (query->as<ASTGrantQuery>())
{
return std::make_unique<InterpreterGrantQuery>(query, context);
}
else if (query->as<ASTShowCreateAccessEntityQuery>())
{
return std::make_unique<InterpreterShowCreateAccessEntityQuery>(query, context);
}
else if (query->as<ASTShowGrantsQuery>())
{
return std::make_unique<InterpreterShowGrantsQuery>(query, context);
}
else if (query->as<ASTShowQuotasQuery>())
{
return std::make_unique<InterpreterShowQuotasQuery>(query, context);

View File

@ -0,0 +1,57 @@
#include <Interpreters/InterpreterGrantQuery.h>
#include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/AccessRightsContext.h>
#include <Access/User.h>
namespace DB
{
BlockIO InterpreterGrantQuery::execute()
{
const auto & query = query_ptr->as<const ASTGrantQuery &>();
auto & access_control = context.getAccessControlManager();
context.getAccessRights()->checkGrantOption(query.access_rights_elements);
using Kind = ASTGrantQuery::Kind;
if (query.to_roles->all_roles)
throw Exception(
"Cannot " + String((query.kind == Kind::GRANT) ? "GRANT to" : "REVOKE from") + " ALL", ErrorCodes::NOT_IMPLEMENTED);
String current_database = context.getCurrentDatabase();
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone());
if (query.kind == Kind::GRANT)
{
updated_user->access.grant(query.access_rights_elements, current_database);
if (query.grant_option)
updated_user->access_with_grant_option.grant(query.access_rights_elements, current_database);
}
else if (context.getSettingsRef().partial_revokes)
{
updated_user->access_with_grant_option.partialRevoke(query.access_rights_elements, current_database);
if (!query.grant_option)
updated_user->access.partialRevoke(query.access_rights_elements, current_database);
}
else
{
updated_user->access_with_grant_option.revoke(query.access_rights_elements, current_database);
if (!query.grant_option)
updated_user->access.revoke(query.access_rights_elements, current_database);
}
return updated_user;
};
std::vector<UUID> ids = access_control.getIDs<User>(query.to_roles->roles);
if (query.to_roles->current_user)
ids.push_back(context.getUserID());
access_control.update(ids, update_func);
return {};
}
}

View File

@ -0,0 +1,20 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class InterpreterGrantQuery : public IInterpreter
{
public:
InterpreterGrantQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
private:
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -154,9 +154,7 @@ String InterpreterSelectQuery::generateFilterActions(ExpressionActionsPtr & acti
/// Using separate expression analyzer to prevent any possible alias injection
auto syntax_result = SyntaxAnalyzer(*context).analyze(query_ast, storage->getColumns().getAllPhysical());
SelectQueryExpressionAnalyzer analyzer(query_ast, syntax_result, *context);
ExpressionActionsChain new_chain(*context);
analyzer.appendSelect(new_chain, false);
actions = new_chain.getLastActions();
actions = analyzer.simpleSelectActions();
return expr_list->children.at(0)->getColumnName();
}
@ -212,17 +210,6 @@ static Context getSubqueryContext(const Context & context)
return subquery_context;
}
static void sanitizeBlock(Block & block)
{
for (auto & col : block)
{
if (!col.column)
col.column = col.type->createColumn();
else if (isColumnConst(*col.column) && !col.column->empty())
col.column = col.column->cloneEmpty();
}
}
InterpreterSelectQuery::InterpreterSelectQuery(
const ASTPtr & query_ptr_,
const Context & context_,
@ -324,7 +311,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
table_id = storage->getStorageID();
}
auto analyze = [&] ()
auto analyze = [&] (bool try_move_to_prewhere = true)
{
syntax_analyzer_result = SyntaxAnalyzer(*context, options).analyze(
query_ptr, source_header.getNamesAndTypesList(), required_result_column_names, storage, NamesAndTypesList());
@ -397,7 +384,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
throw Exception("PREWHERE is not supported if the table is filtered by row-level security expression", ErrorCodes::ILLEGAL_PREWHERE);
/// Calculate structure of the result.
result_header = getSampleBlockImpl();
result_header = getSampleBlockImpl(try_move_to_prewhere);
};
analyze();
@ -425,8 +412,13 @@ InterpreterSelectQuery::InterpreterSelectQuery(
query.setExpression(ASTSelectQuery::Expression::WHERE, makeASTFunction("and", query.prewhere()->clone(), query.where()->clone()));
need_analyze_again = true;
}
if (need_analyze_again)
analyze();
{
/// Do not try to move conditions to PREWHERE a second time.
/// Otherwise, we wouldn't be able to fall back from an inefficient PREWHERE to WHERE later.
analyze(/* try_move_to_prewhere = */ false);
}
/// If there is no WHERE, filter blocks as usual
if (query.prewhere() && !query.where())
@ -509,7 +501,7 @@ QueryPipeline InterpreterSelectQuery::executeWithProcessors()
}
Block InterpreterSelectQuery::getSampleBlockImpl()
Block InterpreterSelectQuery::getSampleBlockImpl(bool try_move_to_prewhere)
{
auto & query = getSelectQuery();
const Settings & settings = context->getSettingsRef();
@ -533,7 +525,7 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
current_info.sets = query_analyzer->getPreparedSets();
/// Try transferring some condition from WHERE to PREWHERE if enabled and viable
if (settings.optimize_move_to_prewhere && query.where() && !query.prewhere() && !query.final())
if (settings.optimize_move_to_prewhere && try_move_to_prewhere && query.where() && !query.prewhere() && !query.final())
MergeTreeWhereOptimizer{current_info, *context, merge_tree,
syntax_analyzer_result->requiredSourceColumns(), log};
};
@ -546,13 +538,17 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
if (storage && !options.only_analyze)
from_stage = storage->getQueryProcessingStage(*context);
analysis_result = analyzeExpressions(
getSelectQuery(),
/// Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
bool first_stage = from_stage < QueryProcessingStage::WithMergeableState
&& options.to_stage >= QueryProcessingStage::WithMergeableState;
/// Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
bool second_stage = from_stage <= QueryProcessingStage::WithMergeableState
&& options.to_stage > QueryProcessingStage::WithMergeableState;
analysis_result = ExpressionAnalysisResult(
*query_analyzer,
from_stage,
options.to_stage,
*context,
storage,
first_stage,
second_stage,
options.only_analyze,
filter_info,
source_header
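
The inlined first_stage/second_stage computation is a pair of half-open range checks against the distributed-processing stage. A self-contained sketch, assuming a reduced three-valued stage enum with the same ordering as QueryProcessingStage:

#include <cassert>

/// Reduced mirror of QueryProcessingStage: data is either raw columns,
/// partially aggregated ("mergeable state"), or fully processed.
enum Stage { FetchColumns = 0, WithMergeableState = 1, Complete = 2 };

int main()
{
    Stage from_stage = FetchColumns; /// what the storage has already done
    Stage to_stage = Complete;       /// what the caller asked for

    /// First part of the pipeline: runs on remote servers during distributed processing.
    bool first_stage = from_stage < WithMergeableState && to_stage >= WithMergeableState;
    /// Second part: runs on the initiating server.
    bool second_stage = from_stage <= WithMergeableState && to_stage > WithMergeableState;

    assert(first_stage && second_stage); /// a fully local query runs both parts
    return 0;
}
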
@ -579,16 +575,12 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
auto header = analysis_result.before_aggregation->getSampleBlock();
Names key_names;
AggregateDescriptions aggregates;
query_analyzer->getAggregateInfo(key_names, aggregates);
Block res;
for (auto & key : key_names)
res.insert({nullptr, header.getByName(key).type, key});
for (auto & key : query_analyzer->aggregationKeys())
res.insert({nullptr, header.getByName(key.name).type, key.name});
for (auto & aggregate : aggregates)
for (auto & aggregate : query_analyzer->aggregates())
{
size_t arguments_size = aggregate.argument_names.size();
DataTypes argument_types(arguments_size);
@ -606,246 +598,6 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
return analysis_result.final_projection->getSampleBlock();
}
/// Check if there is an ignore() function. It's used for disabling constant folding in query
/// predicates because some performance tests use the ignore() function as a non-optimize guard.
static bool hasIgnore(const ExpressionActions & actions)
{
for (auto & action : actions.getActions())
{
if (action.type == action.APPLY_FUNCTION && action.function_base)
{
auto name = action.function_base->getName();
if (name == "ignore")
return true;
}
}
return false;
}
InterpreterSelectQuery::AnalysisResult
InterpreterSelectQuery::analyzeExpressions(
const ASTSelectQuery & query,
SelectQueryExpressionAnalyzer & query_analyzer,
QueryProcessingStage::Enum from_stage,
QueryProcessingStage::Enum to_stage,
const Context & context,
const StoragePtr & storage,
bool only_types,
const FilterInfoPtr & filter_info,
const Block & source_header)
{
AnalysisResult res;
/// Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
res.first_stage = from_stage < QueryProcessingStage::WithMergeableState
&& to_stage >= QueryProcessingStage::WithMergeableState;
/// Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
res.second_stage = from_stage <= QueryProcessingStage::WithMergeableState
&& to_stage > QueryProcessingStage::WithMergeableState;
/** First we compose a chain of actions and remember the necessary steps from it.
* Regardless of from_stage and to_stage, we will compose a complete sequence of actions to perform optimization and
* throw out unnecessary columns based on the entire query. In unnecessary parts of the query, we will not execute subqueries.
*/
bool has_filter = false;
bool has_prewhere = false;
bool has_where = false;
size_t where_step_num;
auto finalizeChain = [&](ExpressionActionsChain & chain)
{
chain.finalize();
if (has_prewhere)
{
const ExpressionActionsChain::Step & step = chain.steps.at(0);
res.prewhere_info->remove_prewhere_column = step.can_remove_required_output.at(0);
Names columns_to_remove;
for (size_t i = 1; i < step.required_output.size(); ++i)
{
if (step.can_remove_required_output[i])
columns_to_remove.push_back(step.required_output[i]);
}
if (!columns_to_remove.empty())
{
auto columns = res.prewhere_info->prewhere_actions->getSampleBlock().getNamesAndTypesList();
ExpressionActionsPtr actions = std::make_shared<ExpressionActions>(columns, context);
for (const auto & column : columns_to_remove)
actions->add(ExpressionAction::removeColumn(column));
res.prewhere_info->remove_columns_actions = std::move(actions);
}
res.columns_to_remove_after_prewhere = std::move(columns_to_remove);
}
else if (has_filter)
{
/// Can't have prewhere and filter set simultaneously
res.filter_info->do_remove_column = chain.steps.at(0).can_remove_required_output.at(0);
}
if (has_where)
res.remove_where_filter = chain.steps.at(where_step_num).can_remove_required_output.at(0);
has_filter = has_prewhere = has_where = false;
chain.clear();
};
{
ExpressionActionsChain chain(context);
Names additional_required_columns_after_prewhere;
if (storage && (query.sample_size() || context.getSettingsRef().parallel_replicas_count > 1))
{
Names columns_for_sampling = storage->getColumnsRequiredForSampling();
additional_required_columns_after_prewhere.insert(additional_required_columns_after_prewhere.end(),
columns_for_sampling.begin(), columns_for_sampling.end());
}
if (storage && query.final())
{
Names columns_for_final = storage->getColumnsRequiredForFinal();
additional_required_columns_after_prewhere.insert(additional_required_columns_after_prewhere.end(),
columns_for_final.begin(), columns_for_final.end());
}
if (storage && filter_info)
{
has_filter = true;
res.filter_info = filter_info;
query_analyzer.appendPreliminaryFilter(chain, filter_info->actions, filter_info->column_name);
}
if (query_analyzer.appendPrewhere(chain, !res.first_stage, additional_required_columns_after_prewhere))
{
has_prewhere = true;
res.prewhere_info = std::make_shared<PrewhereInfo>(
chain.steps.front().actions, query.prewhere()->getColumnName());
if (!hasIgnore(*res.prewhere_info->prewhere_actions))
{
Block before_prewhere_sample = source_header;
sanitizeBlock(before_prewhere_sample);
res.prewhere_info->prewhere_actions->execute(before_prewhere_sample);
auto & column_elem = before_prewhere_sample.getByName(query.prewhere()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
res.prewhere_constant_filter_description = ConstantFilterDescription(*column_elem.column);
}
chain.addStep();
}
res.need_aggregate = query_analyzer.hasAggregation();
query_analyzer.appendArrayJoin(chain, only_types || !res.first_stage);
if (query_analyzer.appendJoin(chain, only_types || !res.first_stage))
{
res.before_join = chain.getLastActions();
if (!res.hasJoin())
throw Exception("No expected JOIN", ErrorCodes::LOGICAL_ERROR);
chain.addStep();
}
if (query_analyzer.appendWhere(chain, only_types || !res.first_stage))
{
where_step_num = chain.steps.size() - 1;
has_where = res.has_where = true;
res.before_where = chain.getLastActions();
if (!hasIgnore(*res.before_where))
{
Block before_where_sample;
if (chain.steps.size() > 1)
before_where_sample = chain.steps[chain.steps.size() - 2].actions->getSampleBlock();
else
before_where_sample = source_header;
sanitizeBlock(before_where_sample);
res.before_where->execute(before_where_sample);
auto & column_elem = before_where_sample.getByName(query.where()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
res.where_constant_filter_description = ConstantFilterDescription(*column_elem.column);
}
chain.addStep();
}
if (res.need_aggregate)
{
query_analyzer.appendGroupBy(chain, only_types || !res.first_stage);
query_analyzer.appendAggregateFunctionsArguments(chain, only_types || !res.first_stage);
res.before_aggregation = chain.getLastActions();
finalizeChain(chain);
if (query_analyzer.appendHaving(chain, only_types || !res.second_stage))
{
res.has_having = true;
res.before_having = chain.getLastActions();
chain.addStep();
}
}
bool has_stream_with_non_joned_rows = (res.before_join && res.before_join->getTableJoinAlgo()->hasStreamWithNonJoinedRows());
res.optimize_read_in_order =
context.getSettingsRef().optimize_read_in_order
&& storage && query.orderBy()
&& !query_analyzer.hasAggregation()
&& !query.final()
&& !has_stream_with_non_joned_rows;
/// If there is aggregation, we execute expressions in SELECT and ORDER BY on the initiating server, otherwise on the source servers.
query_analyzer.appendSelect(chain, only_types || (res.need_aggregate ? !res.second_stage : !res.first_stage));
res.selected_columns = chain.getLastStep().required_output;
res.has_order_by = query_analyzer.appendOrderBy(chain, only_types || (res.need_aggregate ? !res.second_stage : !res.first_stage), res.optimize_read_in_order);
res.before_order_and_select = chain.getLastActions();
chain.addStep();
if (query_analyzer.appendLimitBy(chain, only_types || !res.second_stage))
{
res.has_limit_by = true;
res.before_limit_by = chain.getLastActions();
chain.addStep();
}
query_analyzer.appendProjectResult(chain);
res.final_projection = chain.getLastActions();
finalizeChain(chain);
}
/// Before executing WHERE and HAVING, remove the extra columns from the block (mostly the aggregation keys).
if (res.filter_info)
res.filter_info->actions->prependProjectInput();
if (res.has_where)
res.before_where->prependProjectInput();
if (res.has_having)
res.before_having->prependProjectInput();
res.subqueries_for_sets = query_analyzer.getSubqueriesForSets();
/// Check that PREWHERE doesn't contain unusual actions, i.e. actions that can change the number of rows.
if (res.prewhere_info)
{
auto check_actions = [](const ExpressionActionsPtr & actions)
{
if (actions)
for (const auto & action : actions->getActions())
if (action.type == ExpressionAction::Type::JOIN || action.type == ExpressionAction::Type::ARRAY_JOIN)
throw Exception("PREWHERE cannot contain ARRAY JOIN or JOIN action", ErrorCodes::ILLEGAL_PREWHERE);
};
check_actions(res.prewhere_info->prewhere_actions);
check_actions(res.prewhere_info->alias_actions);
check_actions(res.prewhere_info->remove_columns_actions);
}
return res;
}
static Field getWithFillFieldValue(const ASTPtr & node, const Context & context)
{
const auto & [field, type] = evaluateConstantExpression(node, context);
@ -989,6 +741,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
auto & query = getSelectQuery();
const Settings & settings = context->getSettingsRef();
auto & expressions = analysis_result;
auto & subqueries_for_sets = query_analyzer->getSubqueriesForSets();
if (options.only_analyze)
{
@ -1077,7 +830,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
if (expressions.first_stage)
{
if (expressions.filter_info)
if (expressions.hasFilter())
{
if constexpr (pipeline_with_processors)
{
@ -1159,7 +912,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
}
}
if (expressions.has_where)
if (expressions.hasWhere())
executeWhere(pipeline, expressions.before_where, expressions.remove_where_filter);
if (expressions.need_aggregate)
@ -1175,7 +928,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
* but there is an ORDER or LIMIT,
* then we will perform the preliminary sorting and LIMIT on the remote server.
*/
if (!expressions.second_stage && !expressions.need_aggregate && !expressions.has_having)
if (!expressions.second_stage && !expressions.need_aggregate && !expressions.hasHaving())
{
if (expressions.has_order_by)
executeOrder(pipeline, query_info.input_sorting_info);
@ -1183,7 +936,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
if (expressions.has_order_by && query.limitLength())
executeDistinct(pipeline, false, expressions.selected_columns);
if (expressions.has_limit_by)
if (expressions.hasLimitBy())
{
executeExpression(pipeline, expressions.before_limit_by);
executeLimitBy(pipeline);
@ -1194,8 +947,8 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
}
// If there are no global subqueries, we can run subqueries only when we receive them on the server.
if (!query_analyzer->hasGlobalSubqueries() && !expressions.subqueries_for_sets.empty())
executeSubqueriesInSetsAndJoins(pipeline, expressions.subqueries_for_sets);
if (!query_analyzer->hasGlobalSubqueries() && !subqueries_for_sets.empty())
executeSubqueriesInSetsAndJoins(pipeline, subqueries_for_sets);
}
if (expressions.second_stage)
@ -1213,7 +966,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
if (query.group_by_with_totals)
{
bool final = !query.group_by_with_rollup && !query.group_by_with_cube;
executeTotalsAndHaving(pipeline, expressions.has_having, expressions.before_having, aggregate_overflow_row, final);
executeTotalsAndHaving(pipeline, expressions.hasHaving(), expressions.before_having, aggregate_overflow_row, final);
}
if (query.group_by_with_rollup)
@ -1221,14 +974,14 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
else if (query.group_by_with_cube)
executeRollupOrCube(pipeline, Modificator::CUBE);
if ((query.group_by_with_rollup || query.group_by_with_cube) && expressions.has_having)
if ((query.group_by_with_rollup || query.group_by_with_cube) && expressions.hasHaving())
{
if (query.group_by_with_totals)
throw Exception("WITH TOTALS and WITH ROLLUP or CUBE are not supported together in presence of HAVING", ErrorCodes::NOT_IMPLEMENTED);
executeHaving(pipeline, expressions.before_having);
}
}
else if (expressions.has_having)
else if (expressions.hasHaving())
executeHaving(pipeline, expressions.before_having);
executeExpression(pipeline, expressions.before_order_and_select);
@ -1256,7 +1009,8 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
/** Optimization - if there are several sources and there is LIMIT, then first apply the preliminary LIMIT,
* limiting the number of rows in each up to `offset + limit`.
*/
if (query.limitLength() && !query.limit_with_ties && pipeline.hasMoreThanOneStream() && !query.distinct && !expressions.has_limit_by && !settings.extremes)
if (query.limitLength() && !query.limit_with_ties && pipeline.hasMoreThanOneStream() &&
!query.distinct && !expressions.hasLimitBy() && !settings.extremes)
{
executePreLimit(pipeline);
}
@ -1281,7 +1035,7 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
if (need_second_distinct_pass)
executeDistinct(pipeline, false, expressions.selected_columns);
if (expressions.has_limit_by)
if (expressions.hasLimitBy())
{
executeExpression(pipeline, expressions.before_limit_by);
executeLimitBy(pipeline);
@ -1301,8 +1055,8 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
}
}
if (query_analyzer->hasGlobalSubqueries() && !expressions.subqueries_for_sets.empty())
executeSubqueriesInSetsAndJoins(pipeline, expressions.subqueries_for_sets);
if (query_analyzer->hasGlobalSubqueries() && !subqueries_for_sets.empty())
executeSubqueriesInSetsAndJoins(pipeline, subqueries_for_sets);
}
template <typename TPipeline>
@ -1324,9 +1078,7 @@ void InterpreterSelectQuery::executeFetchColumns(
|| !query_analyzer->hasAggregation() || processing_stage != QueryProcessingStage::FetchColumns)
return {};
Names key_names;
AggregateDescriptions aggregates;
query_analyzer->getAggregateInfo(key_names, aggregates);
const AggregateDescriptions & aggregates = query_analyzer->aggregates();
if (aggregates.size() != 1)
return {};
@ -1639,7 +1391,7 @@ void InterpreterSelectQuery::executeFetchColumns(
if (analysis_result.optimize_read_in_order)
{
query_info.order_by_optimizer = std::make_shared<ReadInOrderOptimizer>(
query_analyzer->getOrderByActions(),
analysis_result.order_by_elements_actions,
getSortDescription(query, *context),
query_info.syntax_analyzer_result);
@ -1866,14 +1618,12 @@ void InterpreterSelectQuery::executeAggregation(Pipeline & pipeline, const Expre
stream = std::make_shared<ExpressionBlockInputStream>(stream, expression);
});
Names key_names;
AggregateDescriptions aggregates;
query_analyzer->getAggregateInfo(key_names, aggregates);
Block header = pipeline.firstStream()->getHeader();
ColumnNumbers keys;
for (const auto & name : key_names)
keys.push_back(header.getPositionByName(name));
for (const auto & key : query_analyzer->aggregationKeys())
keys.push_back(header.getPositionByName(key.name));
AggregateDescriptions aggregates = query_analyzer->aggregates();
for (auto & descr : aggregates)
if (descr.arguments.empty())
for (const auto & name : descr.argument_names)
@ -1932,14 +1682,12 @@ void InterpreterSelectQuery::executeAggregation(QueryPipeline & pipeline, const
return std::make_shared<ExpressionTransform>(header, expression);
});
Names key_names;
AggregateDescriptions aggregates;
query_analyzer->getAggregateInfo(key_names, aggregates);
Block header_before_aggregation = pipeline.getHeader();
ColumnNumbers keys;
for (const auto & name : key_names)
keys.push_back(header_before_aggregation.getPositionByName(name));
for (const auto & key : query_analyzer->aggregationKeys())
keys.push_back(header_before_aggregation.getPositionByName(key.name));
AggregateDescriptions aggregates = query_analyzer->aggregates();
for (auto & descr : aggregates)
if (descr.arguments.empty())
for (const auto & name : descr.argument_names)
@ -2000,15 +1748,11 @@ void InterpreterSelectQuery::executeAggregation(QueryPipeline & pipeline, const
void InterpreterSelectQuery::executeMergeAggregated(Pipeline & pipeline, bool overflow_row, bool final)
{
Names key_names;
AggregateDescriptions aggregates;
query_analyzer->getAggregateInfo(key_names, aggregates);
Block header = pipeline.firstStream()->getHeader();
ColumnNumbers keys;
for (const auto & name : key_names)
keys.push_back(header.getPositionByName(name));
for (const auto & key : query_analyzer->aggregationKeys())
keys.push_back(header.getPositionByName(key.name));
/** There are two modes of distributed aggregation.
*
@ -2027,7 +1771,7 @@ void InterpreterSelectQuery::executeMergeAggregated(Pipeline & pipeline, bool ov
const Settings & settings = context->getSettingsRef();
Aggregator::Params params(header, keys, aggregates, overflow_row, settings.max_threads);
Aggregator::Params params(header, keys, query_analyzer->aggregates(), overflow_row, settings.max_threads);
if (!settings.distributed_aggregation_memory_efficient)
{
@ -2051,15 +1795,11 @@ void InterpreterSelectQuery::executeMergeAggregated(Pipeline & pipeline, bool ov
void InterpreterSelectQuery::executeMergeAggregated(QueryPipeline & pipeline, bool overflow_row, bool final)
{
Names key_names;
AggregateDescriptions aggregates;
query_analyzer->getAggregateInfo(key_names, aggregates);
Block header_before_merge = pipeline.getHeader();
ColumnNumbers keys;
for (const auto & name : key_names)
keys.push_back(header_before_merge.getPositionByName(name));
for (const auto & key : query_analyzer->aggregationKeys())
keys.push_back(header_before_merge.getPositionByName(key.name));
/** There are two modes of distributed aggregation.
*
@ -2078,7 +1818,7 @@ void InterpreterSelectQuery::executeMergeAggregated(QueryPipeline & pipeline, bo
const Settings & settings = context->getSettingsRef();
Aggregator::Params params(header_before_merge, keys, aggregates, overflow_row, settings.max_threads);
Aggregator::Params params(header_before_merge, keys, query_analyzer->aggregates(), overflow_row, settings.max_threads);
auto transform_params = std::make_shared<AggregatingTransformParams>(params, final);
@ -2167,20 +1907,16 @@ void InterpreterSelectQuery::executeRollupOrCube(Pipeline & pipeline, Modificato
{
executeUnion(pipeline, {});
Names key_names;
AggregateDescriptions aggregates;
query_analyzer->getAggregateInfo(key_names, aggregates);
Block header = pipeline.firstStream()->getHeader();
ColumnNumbers keys;
for (const auto & name : key_names)
keys.push_back(header.getPositionByName(name));
for (const auto & key : query_analyzer->aggregationKeys())
keys.push_back(header.getPositionByName(key.name));
const Settings & settings = context->getSettingsRef();
Aggregator::Params params(header, keys, aggregates,
Aggregator::Params params(header, keys, query_analyzer->aggregates(),
false, settings.max_rows_to_group_by, settings.group_by_overflow_mode,
SettingUInt64(0), SettingUInt64(0),
settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set,
@ -2196,20 +1932,16 @@ void InterpreterSelectQuery::executeRollupOrCube(QueryPipeline & pipeline, Modif
{
pipeline.resize(1);
Names key_names;
AggregateDescriptions aggregates;
query_analyzer->getAggregateInfo(key_names, aggregates);
Block header_before_transform = pipeline.getHeader();
ColumnNumbers keys;
for (const auto & name : key_names)
keys.push_back(header_before_transform.getPositionByName(name));
for (const auto & key : query_analyzer->aggregationKeys())
keys.push_back(header_before_transform.getPositionByName(key.name));
const Settings & settings = context->getSettingsRef();
Aggregator::Params params(header_before_transform, keys, aggregates,
Aggregator::Params params(header_before_transform, keys, query_analyzer->aggregates(),
false, settings.max_rows_to_group_by, settings.group_by_overflow_mode,
SettingUInt64(0), SettingUInt64(0),
settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set,
@ -2806,7 +2538,7 @@ void InterpreterSelectQuery::executeExtremes(QueryPipeline & pipeline)
}
void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(Pipeline & pipeline, SubqueriesForSets & subqueries_for_sets)
void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(Pipeline & pipeline, const SubqueriesForSets & subqueries_for_sets)
{
/// Merge streams to one. Use MergeSorting if data was read in sorted order, Union otherwise.
if (query_info.input_sorting_info)
@ -2822,7 +2554,7 @@ void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(Pipeline & pipeline
pipeline.firstStream(), subqueries_for_sets, *context);
}
void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(QueryPipeline & pipeline, SubqueriesForSets & subqueries_for_sets)
void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(QueryPipeline & pipeline, const SubqueriesForSets & subqueries_for_sets)
{
if (query_info.input_sorting_info)
executeMergeSorted(pipeline, query_info.input_sorting_info->order_key_prefix_descr, 0);

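Several hunks above replace getAggregateInfo(key_names, aggregates) with the aggregationKeys() and aggregates() accessors; in every call site the key names are then resolved to positions in the current header, since headers differ between pipeline steps. A standalone sketch of that name-to-position mapping, with a plain std::vector standing in for Block:

#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

/// Stand-in for Block::getPositionByName.
size_t getPositionByName(const std::vector<std::string> & header, const std::string & name)
{
    for (size_t i = 0; i < header.size(); ++i)
        if (header[i] == name)
            return i;
    throw std::runtime_error("no column " + name);
}

int main()
{
    /// The header of the current pipeline step and the keys from aggregationKeys().
    std::vector<std::string> header = {"region", "user_id", "sum(bytes)"};
    std::vector<std::string> key_names = {"region", "user_id"};

    std::vector<size_t> keys; /// plays the role of ColumnNumbers
    for (const auto & name : key_names)
        keys.push_back(getPositionByName(header, name));

    return (keys[0] == 0 && keys[1] == 1) ? 0 : 1;
}
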
View File

@ -104,7 +104,7 @@ private:
ASTSelectQuery & getSelectQuery() { return query_ptr->as<ASTSelectQuery &>(); }
Block getSampleBlockImpl();
Block getSampleBlockImpl(bool try_move_to_prewhere);
struct Pipeline
{
@ -152,55 +152,6 @@ private:
template <typename TPipeline>
void executeImpl(TPipeline & pipeline, const BlockInputStreamPtr & prepared_input, std::optional<Pipe> prepared_pipe, QueryPipeline & save_context_and_storage);
struct AnalysisResult
{
bool hasJoin() const { return before_join.get(); }
bool has_where = false;
bool need_aggregate = false;
bool has_having = false;
bool has_order_by = false;
bool has_limit_by = false;
bool remove_where_filter = false;
bool optimize_read_in_order = false;
ExpressionActionsPtr before_join; /// including JOIN
ExpressionActionsPtr before_where;
ExpressionActionsPtr before_aggregation;
ExpressionActionsPtr before_having;
ExpressionActionsPtr before_order_and_select;
ExpressionActionsPtr before_limit_by;
ExpressionActionsPtr final_projection;
/// Columns from the SELECT list, before renaming them to aliases.
Names selected_columns;
/// Columns will be removed after prewhere actions execution.
Names columns_to_remove_after_prewhere;
/// Do I need to perform the first part of the pipeline - running on remote servers during distributed processing.
bool first_stage = false;
/// Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing.
bool second_stage = false;
SubqueriesForSets subqueries_for_sets;
PrewhereInfoPtr prewhere_info;
FilterInfoPtr filter_info;
ConstantFilterDescription prewhere_constant_filter_description;
ConstantFilterDescription where_constant_filter_description;
};
static AnalysisResult analyzeExpressions(
const ASTSelectQuery & query,
SelectQueryExpressionAnalyzer & query_analyzer,
QueryProcessingStage::Enum from_stage,
QueryProcessingStage::Enum to_stage,
const Context & context,
const StoragePtr & storage,
bool only_types,
const FilterInfoPtr & filter_info,
const Block & source_header);
/** From which table to read. With JOIN, the "left" table is returned.
*/
static void getDatabaseAndTableNames(const ASTSelectQuery & query, String & database_name, String & table_name, const Context & context);
@ -232,7 +183,7 @@ private:
void executeProjection(Pipeline & pipeline, const ExpressionActionsPtr & expression);
void executeDistinct(Pipeline & pipeline, bool before_order, Names columns);
void executeExtremes(Pipeline & pipeline);
void executeSubqueriesInSetsAndJoins(Pipeline & pipeline, std::unordered_map<String, SubqueryForSet> & subqueries_for_sets);
void executeSubqueriesInSetsAndJoins(Pipeline & pipeline, const std::unordered_map<String, SubqueryForSet> & subqueries_for_sets);
void executeMergeSorted(Pipeline & pipeline, const SortDescription & sort_description, UInt64 limit);
void executeWhere(QueryPipeline & pipeline, const ExpressionActionsPtr & expression, bool remove_filter);
@ -250,7 +201,7 @@ private:
void executeProjection(QueryPipeline & pipeline, const ExpressionActionsPtr & expression);
void executeDistinct(QueryPipeline & pipeline, bool before_order, Names columns);
void executeExtremes(QueryPipeline & pipeline);
void executeSubqueriesInSetsAndJoins(QueryPipeline & pipeline, std::unordered_map<String, SubqueryForSet> & subqueries_for_sets);
void executeSubqueriesInSetsAndJoins(QueryPipeline & pipeline, const std::unordered_map<String, SubqueryForSet> & subqueries_for_sets);
void executeMergeSorted(QueryPipeline & pipeline, const SortDescription & sort_description, UInt64 limit);
String generateFilterActions(ExpressionActionsPtr & actions, const ASTPtr & row_policy_filter, const Names & prerequisite_columns = {}) const;
@ -284,7 +235,7 @@ private:
SelectQueryInfo query_info;
/// Is calculated in getSampleBlock. Is used later in readImpl.
AnalysisResult analysis_result;
ExpressionAnalysisResult analysis_result;
FilterInfoPtr filter_info;
QueryProcessingStage::Enum from_stage = QueryProcessingStage::FetchColumns;

View File

@ -1,5 +1,6 @@
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTCreateUserQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
@ -9,6 +10,7 @@
#include <Parsers/parseQuery.h>
#include <Access/AccessControlManager.h>
#include <Access/QuotaContext.h>
#include <Access/User.h>
#include <Columns/ColumnString.h>
#include <DataStreams/OneBlockInputStream.h>
#include <DataTypes/DataTypeString.h>
@ -58,6 +60,7 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(const ASTShowCreat
using Kind = ASTShowCreateAccessEntityQuery::Kind;
switch (show_query.kind)
{
case Kind::USER: return getCreateUserQuery(show_query);
case Kind::QUOTA: return getCreateQuotaQuery(show_query);
case Kind::ROW_POLICY: return getCreateRowPolicyQuery(show_query);
}
@ -65,6 +68,27 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(const ASTShowCreat
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateUserQuery(const ASTShowCreateAccessEntityQuery & show_query) const
{
UserPtr user;
if (show_query.current_user)
user = context.getUser();
else
user = context.getAccessControlManager().getUser(show_query.name);
auto create_query = std::make_shared<ASTCreateUserQuery>();
create_query->name = user->getName();
if (!user->allowed_client_hosts.containsAnyHost())
create_query->hosts = user->allowed_client_hosts;
if (!user->profile.empty())
create_query->profile = user->profile;
return create_query;
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const
{
auto & access_control = context.getAccessControlManager();

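getCreateUserQuery rebuilds a CREATE USER statement from the stored User object and emits only the clauses that differ from defaults (host restrictions, profile). A rough sketch of that emit-only-non-defaults idea; the types and the exact clause syntax here are illustrative, not the real AST or grammar:

#include <iostream>
#include <string>
#include <vector>

/// Simplified stand-in for the stored entity.
struct User
{
    std::string name;
    std::vector<std::string> hosts; /// empty means "any host", the default
    std::string profile;            /// empty means no profile clause
};

/// Emit only the clauses that differ from defaults, as the interpreter
/// does when it fills the create query's fields conditionally.
std::string showCreateUser(const User & user)
{
    std::string query = "CREATE USER " + user.name;
    for (const auto & host : user.hosts)
        query += " HOST NAME '" + host + "'";
    if (!user.profile.empty())
        query += " SETTINGS PROFILE '" + user.profile + "'";
    return query;
}

int main()
{
    std::cout << showCreateUser({"alice", {"example.com"}, "readonly"}) << '\n';
}
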
View File

@ -29,6 +29,7 @@ private:
BlockInputStreamPtr executeImpl();
ASTPtr getCreateQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
ASTPtr getCreateUserQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
ASTPtr getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
ASTPtr getCreateRowPolicyQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
};

View File

@ -0,0 +1,124 @@
#include <Interpreters/InterpreterShowGrantsQuery.h>
#include <Parsers/ASTShowGrantsQuery.h>
#include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/formatAST.h>
#include <Interpreters/Context.h>
#include <Columns/ColumnString.h>
#include <DataStreams/OneBlockInputStream.h>
#include <DataTypes/DataTypeString.h>
#include <Access/AccessControlManager.h>
#include <Access/User.h>
#include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/copy.hpp>
namespace DB
{
namespace
{
std::vector<AccessRightsElements> groupByTable(AccessRightsElements && elements)
{
using Key = std::tuple<String, bool, String, bool>;
std::map<Key, AccessRightsElements> grouping_map;
for (auto & element : elements)
{
Key key(element.database, element.any_database, element.table, element.any_table);
grouping_map[key].emplace_back(std::move(element));
}
std::vector<AccessRightsElements> res;
res.reserve(grouping_map.size());
boost::range::copy(grouping_map | boost::adaptors::map_values, std::back_inserter(res));
return res;
}
struct GroupedGrantsAndPartialRevokes
{
std::vector<AccessRightsElements> grants;
std::vector<AccessRightsElements> partial_revokes;
};
GroupedGrantsAndPartialRevokes groupByTable(AccessRights::Elements && elements)
{
GroupedGrantsAndPartialRevokes res;
res.grants = groupByTable(std::move(elements.grants));
res.partial_revokes = groupByTable(std::move(elements.partial_revokes));
return res;
}
}
BlockIO InterpreterShowGrantsQuery::execute()
{
BlockIO res;
res.in = executeImpl();
return res;
}
BlockInputStreamPtr InterpreterShowGrantsQuery::executeImpl()
{
const auto & show_query = query_ptr->as<ASTShowGrantsQuery &>();
/// Build GRANT queries.
ASTs grant_queries = getGrantQueries(show_query);
/// Build the result column.
MutableColumnPtr column = ColumnString::create();
std::stringstream grant_ss;
for (const auto & grant_query : grant_queries)
{
grant_ss.str("");
formatAST(*grant_query, grant_ss, false, true);
column->insert(grant_ss.str());
}
/// Prepare description of the result column.
std::stringstream desc_ss;
formatAST(show_query, desc_ss, false, true);
String desc = desc_ss.str();
String prefix = "SHOW ";
if (desc.starts_with(prefix))
desc = desc.substr(prefix.length()); /// `desc` always starts with "SHOW ", so we can trim this prefix.
return std::make_shared<OneBlockInputStream>(Block{{std::move(column), std::make_shared<DataTypeString>(), desc}});
}
ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show_query) const
{
UserPtr user;
if (show_query.current_user)
user = context.getUser();
else
user = context.getAccessControlManager().getUser(show_query.name);
ASTs res;
for (bool grant_option : {true, false})
{
if (!grant_option && (user->access == user->access_with_grant_option))
continue;
const auto & access_rights = grant_option ? user->access_with_grant_option : user->access;
const auto grouped_elements = groupByTable(access_rights.getElements());
using Kind = ASTGrantQuery::Kind;
for (Kind kind : {Kind::GRANT, Kind::REVOKE})
{
for (const auto & elements : (kind == Kind::GRANT ? grouped_elements.grants : grouped_elements.partial_revokes))
{
auto grant_query = std::make_shared<ASTGrantQuery>();
grant_query->kind = kind;
grant_query->grant_option = grant_option;
grant_query->to_roles = std::make_shared<ASTRoleList>();
grant_query->to_roles->roles.push_back(user->getName());
grant_query->access_rights_elements = elements;
res.push_back(std::move(grant_query));
}
}
}
return res;
}
}

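groupByTable above buckets access-rights elements by their (database, any_database, table, any_table) scope so that each emitted GRANT or REVOKE covers exactly one table scope; keying a std::map with a tuple also makes the output order deterministic. A reduced, compilable version of the same grouping with a simplified element type:

#include <map>
#include <string>
#include <tuple>
#include <vector>

struct Element { std::string database, table, access; };

/// Group elements by table scope. Keying a std::map with a tuple keeps
/// the generated statements in a deterministic order.
std::vector<std::vector<Element>> groupByTable(std::vector<Element> && elements)
{
    using Key = std::tuple<std::string, std::string>;
    std::map<Key, std::vector<Element>> grouping_map;
    for (auto & element : elements)
        grouping_map[Key(element.database, element.table)].emplace_back(std::move(element));

    std::vector<std::vector<Element>> res;
    res.reserve(grouping_map.size());
    for (auto & [key, group] : grouping_map)
        res.emplace_back(std::move(group));
    return res;
}

int main()
{
    auto groups = groupByTable({{"db", "t1", "SELECT"}, {"db", "t2", "INSERT"}, {"db", "t1", "INSERT"}});
    return groups.size() == 2 ? 0 : 1; /// t1 yields one GRANT with two rights, t2 another
}
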
View File

@ -0,0 +1,26 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class ASTShowGrantsQuery;
class InterpreterShowGrantsQuery : public IInterpreter
{
public:
InterpreterShowGrantsQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
private:
BlockInputStreamPtr executeImpl();
ASTs getGrantQueries(const ASTShowGrantsQuery & show_query) const;
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -18,6 +18,7 @@
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/formatAST.h>
#include <Parsers/ASTOrderByElement.h>
#include <IO/WriteHelpers.h>
@ -525,6 +526,39 @@ ASTPtr MutationsInterpreter::prepareInterpreterSelectQuery(std::vector<Stage> &
}
select->setExpression(ASTSelectQuery::Expression::WHERE, std::move(where_expression));
}
auto metadata = storage->getInMemoryMetadata();
/// We have to execute the SELECT in primary key order, because we don't sort
/// results additionally and have no guarantees on data order without ORDER BY.
/// It's almost free, because we already have an optimization for reading data
/// in primary key order.
if (metadata.order_by_ast)
{
ASTPtr dummy;
ASTPtr key_expr;
if (metadata.primary_key_ast)
key_expr = metadata.primary_key_ast;
else
key_expr = metadata.order_by_ast;
bool empty = false;
/// In all other cases the key cannot be empty
if (auto key_function = key_expr->as<ASTFunction>())
empty = key_function->arguments->children.size() == 0;
/// The key was not explicitly specified as empty
if (!empty)
{
auto order_by_expr = std::make_shared<ASTOrderByElement>(1, 1, false, dummy, false, dummy, dummy, dummy);
order_by_expr->children.push_back(key_expr);
auto res = std::make_shared<ASTExpressionList>();
res->children.push_back(order_by_expr);
select->setExpression(ASTSelectQuery::Expression::ORDER_BY, std::move(res));
}
}
return select;
}

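The new block selects the primary key expression when present, falls back to the ORDER BY expression, and skips the injection only for an explicitly empty key. A sketch of just that selection logic; Expr is a deliberately simplified stand-in for the AST node types:

#include <memory>
#include <string>
#include <vector>

/// Deliberately simplified AST node: a named expression with children,
/// so an explicitly empty key looks like tuple() with no children.
struct Expr
{
    std::string name;
    std::vector<std::shared_ptr<Expr>> children;
};

using ExprPtr = std::shared_ptr<Expr>;

/// Mirror of the selection above: prefer the primary key, fall back to
/// ORDER BY, and skip an explicitly empty tuple() key.
ExprPtr chooseSortKey(const ExprPtr & primary_key, const ExprPtr & order_by)
{
    ExprPtr key_expr = primary_key ? primary_key : order_by;
    if (key_expr && key_expr->name == "tuple" && key_expr->children.empty())
        return nullptr; /// explicitly empty key: nothing to sort by
    return key_expr;
}

int main()
{
    auto empty_key = std::make_shared<Expr>(Expr{"tuple", {}});
    auto order_by = std::make_shared<Expr>(Expr{"column_x", {}});
    return (chooseSortKey(nullptr, order_by) && !chooseSortKey(empty_key, nullptr)) ? 0 : 1;
}
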
View File

@ -181,12 +181,12 @@ ProcessList::EntryPtr ProcessList::insert(const String & query_, const IAST * as
/// You should specify this value in configuration for default profile,
/// not for specific users, sessions or queries,
/// because this setting is effectively global.
total_memory_tracker.setOrRaiseLimit(settings.max_memory_usage_for_all_queries);
total_memory_tracker.setOrRaiseHardLimit(settings.max_memory_usage_for_all_queries);
total_memory_tracker.setDescription("(total)");
/// Track memory usage for all simultaneously running queries from single user.
user_process_list.user_memory_tracker.setParent(&total_memory_tracker);
user_process_list.user_memory_tracker.setOrRaiseLimit(settings.max_memory_usage_for_user);
user_process_list.user_memory_tracker.setOrRaiseHardLimit(settings.max_memory_usage_for_user);
user_process_list.user_memory_tracker.setDescription("(for user)");
/// Actualize thread group info
@ -198,7 +198,9 @@ ProcessList::EntryPtr ProcessList::insert(const String & query_, const IAST * as
thread_group->query = process_it->query;
/// Set query-level memory trackers
thread_group->memory_tracker.setOrRaiseLimit(process_it->max_memory_usage);
thread_group->memory_tracker.setOrRaiseHardLimit(process_it->max_memory_usage);
thread_group->memory_tracker.setOrRaiseProfilerLimit(settings.memory_profiler_step);
thread_group->memory_tracker.setProfilerStep(settings.memory_profiler_step);
thread_group->memory_tracker.setDescription("(for query)");
if (process_it->memory_tracker_fault_probability)
thread_group->memory_tracker.setFaultProbability(process_it->memory_tracker_fault_probability);

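The renamed setOrRaiseHardLimit makes the monotonic semantics explicit: the limit may only grow, so a narrower scope (a query) can never lower a broader limit (the user or server total), while the profiler step is now configured separately. An illustrative tracker with the same set-or-raise behavior, implemented as a maximizing CAS loop (not the real MemoryTracker):

#include <atomic>
#include <cassert>
#include <cstdint>

/// Illustrative tracker: the hard limit may only be raised, never lowered.
struct Tracker
{
    std::atomic<int64_t> hard_limit{0};

    void setOrRaiseHardLimit(int64_t value)
    {
        int64_t old_value = hard_limit.load();
        while (value > old_value && !hard_limit.compare_exchange_weak(old_value, value))
            ; /// on failure, old_value is refreshed and the check repeats
    }
};

int main()
{
    Tracker tracker;
    tracker.setOrRaiseHardLimit(100);
    tracker.setOrRaiseHardLimit(50); /// ignored: lower than the current limit
    assert(tracker.hard_limit.load() == 100);
    return 0;
}
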
View File

@ -2,6 +2,7 @@
#include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Interpreters/Join.h>
#include <Interpreters/MergeJoin.h>
#include <Interpreters/ExpressionActions.h>
#include <DataStreams/LazyBlockInputStream.h>
namespace DB

Some files were not shown because too many files have changed in this diff.