Merge branch 'master' of github.com:ClickHouse/ClickHouse

This commit is contained in:
Slach 2020-02-07 08:03:45 +05:00
commit 1bff5578b6
113 changed files with 4567 additions and 992 deletions

View File

@ -218,7 +218,7 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
try
{
// For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
auto user = connection_context.getUser(user_name);
auto user = connection_context.getAccessControlManager().getUser(user_name);
const DB::Authentication::Type user_auth_type = user->authentication.getType();
if (user_auth_type != DB::Authentication::DOUBLE_SHA1_PASSWORD && user_auth_type != DB::Authentication::PLAINTEXT_PASSWORD && user_auth_type != DB::Authentication::NO_PASSWORD)
{

View File

@ -35,45 +35,49 @@ AccessControlManager::~AccessControlManager()
}
UserPtr AccessControlManager::getUser(const String & user_name) const
UserPtr AccessControlManager::getUser(
const String & user_name, std::function<void(const UserPtr &)> on_change, ext::scope_guard * subscription) const
{
return getUser(user_name, {}, nullptr);
return getUser(getID<User>(user_name), std::move(on_change), subscription);
}
UserPtr AccessControlManager::getUser(
const String & user_name, const std::function<void(const UserPtr &)> & on_change, ext::scope_guard * subscription) const
const UUID & user_id, std::function<void(const UserPtr &)> on_change, ext::scope_guard * subscription) const
{
UUID id = getID<User>(user_name);
if (on_change && subscription)
{
*subscription = subscribeForChanges(id, [on_change](const UUID &, const AccessEntityPtr & user)
*subscription = subscribeForChanges(user_id, [on_change](const UUID &, const AccessEntityPtr & user)
{
if (user)
on_change(typeid_cast<UserPtr>(user));
});
}
return read<User>(id);
return read<User>(user_id);
}
UserPtr AccessControlManager::authorizeAndGetUser(
const String & user_name,
const String & password,
const Poco::Net::IPAddress & address) const
{
return authorizeAndGetUser(user_name, password, address, {}, nullptr);
}
UserPtr AccessControlManager::authorizeAndGetUser(
const String & user_name,
const String & password,
const Poco::Net::IPAddress & address,
const std::function<void(const UserPtr &)> & on_change,
std::function<void(const UserPtr &)> on_change,
ext::scope_guard * subscription) const
{
auto user = getUser(user_name, on_change, subscription);
user->allowed_client_hosts.checkContains(address, user_name);
user->authentication.checkPassword(password, user_name);
return authorizeAndGetUser(getID<User>(user_name), password, address, std::move(on_change), subscription);
}
UserPtr AccessControlManager::authorizeAndGetUser(
const UUID & user_id,
const String & password,
const Poco::Net::IPAddress & address,
std::function<void(const UserPtr &)> on_change,
ext::scope_guard * subscription) const
{
auto user = getUser(user_id, on_change, subscription);
user->allowed_client_hosts.checkContains(address, user->getName());
user->authentication.checkPassword(password, user->getName());
return user;
}
@ -85,9 +89,9 @@ void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguratio
}
std::shared_ptr<const AccessRightsContext> AccessControlManager::getAccessRightsContext(const ClientInfo & client_info, const AccessRights & granted_to_user, const Settings & settings, const String & current_database)
std::shared_ptr<const AccessRightsContext> AccessControlManager::getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database)
{
return std::make_shared<AccessRightsContext>(client_info, granted_to_user, settings, current_database);
return std::make_shared<AccessRightsContext>(user, client_info, settings, current_database);
}

View File

@ -42,12 +42,12 @@ public:
void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config);
UserPtr getUser(const String & user_name) const;
UserPtr getUser(const String & user_name, const std::function<void(const UserPtr &)> & on_change, ext::scope_guard * subscription) const;
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address) const;
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address, const std::function<void(const UserPtr &)> & on_change, ext::scope_guard * subscription) const;
UserPtr getUser(const String & user_name, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
UserPtr getUser(const UUID & user_id, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
UserPtr authorizeAndGetUser(const UUID & user_id, const String & password, const Poco::Net::IPAddress & address, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const;
std::shared_ptr<const AccessRightsContext> getAccessRightsContext(const ClientInfo & client_info, const AccessRights & granted_to_user, const Settings & settings, const String & current_database);
std::shared_ptr<const AccessRightsContext> getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database);
std::shared_ptr<QuotaContext>
createQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key);

View File

@ -1,9 +1,9 @@
#include <Access/AccessRights.h>
#include <Common/Exception.h>
#include <common/logger_useful.h>
#include <boost/range/adaptor/map.hpp>
#include <unordered_map>
namespace DB
{
namespace ErrorCodes
@ -73,6 +73,7 @@ public:
inherited_access = src.inherited_access;
explicit_grants = src.explicit_grants;
partial_revokes = src.partial_revokes;
raw_access = src.raw_access;
access = src.access;
min_access = src.min_access;
max_access = src.max_access;
@ -114,8 +115,12 @@ public:
access_to_grant = grantable;
}
explicit_grants |= access_to_grant - partial_revokes;
partial_revokes -= access_to_grant;
AccessFlags new_explicit_grants = access_to_grant - partial_revokes;
if (level == TABLE_LEVEL)
removeExplicitGrantsRec(new_explicit_grants);
removePartialRevokesRec(access_to_grant);
explicit_grants |= new_explicit_grants;
calculateAllAccessRec(helper);
}
@ -147,16 +152,27 @@ public:
{
if constexpr (mode == NORMAL_REVOKE_MODE)
{
explicit_grants -= access_to_revoke;
if (level == TABLE_LEVEL)
removeExplicitGrantsRec(access_to_revoke);
else
removeExplicitGrants(access_to_revoke);
}
else if constexpr (mode == PARTIAL_REVOKE_MODE)
{
partial_revokes |= access_to_revoke - explicit_grants;
explicit_grants -= access_to_revoke;
AccessFlags new_partial_revokes = access_to_revoke - explicit_grants;
if (level == TABLE_LEVEL)
removeExplicitGrantsRec(access_to_revoke);
else
removeExplicitGrants(access_to_revoke);
removePartialRevokesRec(new_partial_revokes);
partial_revokes |= new_partial_revokes;
}
else /// mode == FULL_REVOKE_MODE
{
fullRevokeRec(access_to_revoke);
AccessFlags new_partial_revokes = access_to_revoke - explicit_grants;
removeExplicitGrantsRec(access_to_revoke);
removePartialRevokesRec(new_partial_revokes);
partial_revokes |= new_partial_revokes;
}
calculateAllAccessRec(helper);
}
@ -272,6 +288,24 @@ public:
calculateAllAccessRec(helper);
}
void traceTree(Poco::Logger * log) const
{
LOG_TRACE(log, "Tree(" << level << "): name=" << (node_name ? *node_name : "NULL")
<< ", explicit_grants=" << explicit_grants.toString()
<< ", partial_revokes=" << partial_revokes.toString()
<< ", inherited_access=" << inherited_access.toString()
<< ", raw_access=" << raw_access.toString()
<< ", access=" << access.toString()
<< ", min_access=" << min_access.toString()
<< ", max_access=" << max_access.toString()
<< ", num_children=" << (children ? children->size() : 0));
if (children)
{
for (auto & child : *children | boost::adaptors::map_values)
child.traceTree(log);
}
}
private:
Node * tryGetChild(const std::string_view & name)
{
@ -371,14 +405,28 @@ private:
calculateMinAndMaxAccess();
}
void fullRevokeRec(const AccessFlags & access_to_revoke)
void removeExplicitGrants(const AccessFlags & change)
{
explicit_grants -= access_to_revoke;
partial_revokes |= access_to_revoke;
explicit_grants -= change;
}
void removeExplicitGrantsRec(const AccessFlags & change)
{
removeExplicitGrants(change);
if (children)
{
for (auto & child : *children | boost::adaptors::map_values)
child.fullRevokeRec(access_to_revoke);
child.removeExplicitGrantsRec(change);
}
}
void removePartialRevokesRec(const AccessFlags & change)
{
partial_revokes -= change;
if (children)
{
for (auto & child : *children | boost::adaptors::map_values)
child.removePartialRevokesRec(change);
}
}
@ -726,4 +774,13 @@ void AccessRights::merge(const AccessRights & other)
}
}
void AccessRights::traceTree() const
{
auto * log = &Poco::Logger::get("AccessRights");
if (root)
root->traceTree(log);
else
LOG_TRACE(log, "Tree: NULL");
}
}

View File

@ -130,6 +130,8 @@ private:
template <typename... Args>
AccessFlags getAccessImpl(const Args &... args) const;
void traceTree() const;
struct Node;
std::unique_ptr<Node> root;
};

View File

@ -1,4 +1,5 @@
#include <Access/AccessRightsContext.h>
#include <Access/User.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
#include <Core/Settings.h>
@ -88,24 +89,23 @@ AccessRightsContext::AccessRightsContext()
}
AccessRightsContext::AccessRightsContext(const ClientInfo & client_info_, const AccessRights & granted_to_user_, const Settings & settings, const String & current_database_)
: user_name(client_info_.current_user)
, granted_to_user(granted_to_user_)
AccessRightsContext::AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_)
: user(user_)
, readonly(settings.readonly)
, allow_ddl(settings.allow_ddl)
, allow_introspection(settings.allow_introspection_functions)
, current_database(current_database_)
, interface(client_info_.interface)
, http_method(client_info_.http_method)
, trace_log(&Poco::Logger::get("AccessRightsContext (" + user_name + ")"))
, trace_log(&Poco::Logger::get("AccessRightsContext (" + user_->getName() + ")"))
{
}
template <int mode, typename... Args>
template <int mode, bool grant_option, typename... Args>
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const
{
auto result_access = calculateResultAccess();
auto result_access = calculateResultAccess(grant_option);
bool is_granted = result_access->isGranted(access, args...);
if (trace_log)
@ -126,12 +126,21 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
auto show_error = [&](const String & msg, [[maybe_unused]] int error_code)
{
if constexpr (mode == THROW_IF_ACCESS_DENIED)
throw Exception(msg, error_code);
throw Exception(user->getName() + ": " + msg, error_code);
else if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED)
LOG_WARNING(log_, msg + formatSkippedMessage(args...));
LOG_WARNING(log_, user->getName() + ": " + msg + formatSkippedMessage(args...));
};
if (readonly && calculateResultAccess(false, allow_ddl, allow_introspection)->isGranted(access, args...))
if (grant_option && calculateResultAccess(false, readonly, allow_ddl, allow_introspection)->isGranted(access, args...))
{
show_error(
"Not enough privileges. "
"The required privileges have been granted, but without grant option. "
"To execute this query it's necessary to have the grant "
+ AccessRightsElement{access, args...}.toString() + " WITH GRANT OPTION",
ErrorCodes::ACCESS_DENIED);
}
else if (readonly && calculateResultAccess(false, false, allow_ddl, allow_introspection)->isGranted(access, args...))
{
if (interface == ClientInfo::Interface::HTTP && http_method == ClientInfo::HTTPMethod::GET)
show_error(
@ -141,108 +150,116 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
else
show_error("Cannot execute query in readonly mode", ErrorCodes::READONLY);
}
else if (!allow_ddl && calculateResultAccess(readonly, true, allow_introspection)->isGranted(access, args...))
else if (!allow_ddl && calculateResultAccess(false, readonly, true, allow_introspection)->isGranted(access, args...))
{
show_error("Cannot execute query. DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED);
}
else if (!allow_introspection && calculateResultAccess(readonly, allow_ddl, true)->isGranted(access, args...))
else if (!allow_introspection && calculateResultAccess(false, readonly, allow_ddl, true)->isGranted(access, args...))
{
show_error("Introspection functions are disabled, because setting 'allow_introspection_functions' is set to 0", ErrorCodes::FUNCTION_NOT_ALLOWED);
}
else
{
show_error(
user_name + ": Not enough privileges. To perform this operation you should have grant "
+ AccessRightsElement{access, args...}.toString(),
"Not enough privileges. To execute this query it's necessary to have the grant "
+ AccessRightsElement{access, args...}.toString() + (grant_option ? " WITH GRANT OPTION" : ""),
ErrorCodes::ACCESS_DENIED);
}
return false;
}
template <int mode>
template <int mode, bool grant_option>
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElement & element) const
{
if (element.any_database)
{
return checkImpl<mode>(log_, element.access_flags);
return checkImpl<mode, grant_option>(log_, element.access_flags);
}
else if (element.any_table)
{
if (element.database.empty())
return checkImpl<mode>(log_, element.access_flags, current_database);
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database);
else
return checkImpl<mode>(log_, element.access_flags, element.database);
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database);
}
else if (element.any_column)
{
if (element.database.empty())
return checkImpl<mode>(log_, element.access_flags, current_database, element.table);
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database, element.table);
else
return checkImpl<mode>(log_, element.access_flags, element.database, element.table);
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table);
}
else
{
if (element.database.empty())
return checkImpl<mode>(log_, element.access_flags, current_database, element.table, element.columns);
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database, element.table, element.columns);
else
return checkImpl<mode>(log_, element.access_flags, element.database, element.table, element.columns);
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table, element.columns);
}
}
template <int mode>
template <int mode, bool grant_option>
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElements & elements) const
{
for (const auto & element : elements)
if (!checkImpl<mode>(log_, element))
if (!checkImpl<mode, grant_option>(log_, element))
return false;
return true;
}
void AccessRightsContext::check(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table, column); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access); }
void AccessRightsContext::check(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED>(nullptr, access); }
void AccessRightsContext::check(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
void AccessRightsContext::check(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table, column); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table, column); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, column); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, column); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); }
void AccessRightsContext::checkGrantOption(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess() const
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option) const
{
auto res = result_access_cache[0].load();
if (res)
return res;
return calculateResultAccess(readonly, allow_ddl, allow_introspection);
return calculateResultAccess(grant_option, readonly, allow_ddl, allow_introspection);
}
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const
{
size_t cache_index = static_cast<size_t>(readonly_ != readonly)
+ static_cast<size_t>(allow_ddl_ != allow_ddl) * 2 +
+ static_cast<size_t>(allow_introspection_ != allow_introspection) * 3;
+ static_cast<size_t>(allow_introspection_ != allow_introspection) * 3
+ static_cast<size_t>(grant_option) * 4;
assert(cache_index < std::size(result_access_cache));
auto cached = result_access_cache[cache_index].load();
if (cached)
@ -256,7 +273,7 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
auto result_ptr = boost::make_shared<AccessRights>();
auto & result = *result_ptr;
result = granted_to_user;
result = grant_option ? user->access_with_grant_option : user->access;
static const AccessFlags table_ddl = AccessType::CREATE_DATABASE | AccessType::CREATE_TABLE | AccessType::CREATE_VIEW
| AccessType::ALTER_TABLE | AccessType::ALTER_VIEW | AccessType::DROP_DATABASE | AccessType::DROP_TABLE | AccessType::DROP_VIEW
@ -265,12 +282,18 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
static const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl;
static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE;
/// Anyone has access to the "system" database.
result.grant(AccessType::SELECT, "system");
if (readonly_)
result.fullRevoke(write_table_access | AccessType::SYSTEM);
if (readonly_ || !allow_ddl_)
result.fullRevoke(table_and_dictionary_ddl);
if (readonly_ && grant_option)
result.fullRevoke(AccessType::ALL);
if (readonly_ == 1)
{
/// Table functions are forbidden in readonly mode.
@ -282,7 +305,11 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
result.fullRevoke(AccessType::INTROSPECTION);
result_access_cache[cache_index].store(result_ptr);
return std::move(result_ptr);
if (trace_log && (readonly == readonly_) && (allow_ddl == allow_ddl_) && (allow_introspection == allow_introspection_))
LOG_TRACE(trace_log, "List of all grants: " << result_ptr->toString() << (grant_option ? " WITH GRANT OPTION" : ""));
return result_ptr;
}
}

View File

@ -11,6 +11,8 @@ namespace Poco { class Logger; }
namespace DB
{
struct Settings;
struct User;
using UserPtr = std::shared_ptr<const User>;
class AccessRightsContext
@ -19,7 +21,7 @@ public:
/// Default constructor creates access rights' context which allows everything.
AccessRightsContext();
AccessRightsContext(const ClientInfo & client_info_, const AccessRights & granted_to_user, const Settings & settings, const String & current_database_);
AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_);
/// Checks if a specified access granted, and throws an exception if not.
/// Empty database means the current database.
@ -52,21 +54,30 @@ public:
bool isGranted(Poco::Logger * log_, const AccessRightsElement & access) const;
bool isGranted(Poco::Logger * log_, const AccessRightsElements & access) const;
/// Checks if a specified access granted with grant option, and throws an exception if not.
void checkGrantOption(const AccessFlags & access) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const;
void checkGrantOption(const AccessRightsElement & access) const;
void checkGrantOption(const AccessRightsElements & access) const;
private:
template <int mode, typename... Args>
template <int mode, bool grant_option, typename... Args>
bool checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const;
template <int mode>
template <int mode, bool grant_option>
bool checkImpl(Poco::Logger * log_, const AccessRightsElement & access) const;
template <int mode>
template <int mode, bool grant_option>
bool checkImpl(Poco::Logger * log_, const AccessRightsElements & access) const;
boost::shared_ptr<const AccessRights> calculateResultAccess() const;
boost::shared_ptr<const AccessRights> calculateResultAccess(UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const;
boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option) const;
boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const;
const String user_name;
const AccessRights granted_to_user;
const UserPtr user;
const UInt64 readonly = 0;
const bool allow_ddl = true;
const bool allow_introspection = true;
@ -74,7 +85,7 @@ private:
const ClientInfo::Interface interface = ClientInfo::Interface::TCP;
const ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN;
Poco::Logger * const trace_log = nullptr;
mutable boost::atomic_shared_ptr<const AccessRights> result_access_cache[4];
mutable boost::atomic_shared_ptr<const AccessRights> result_access_cache[7];
mutable std::mutex mutex;
};

View File

@ -1,15 +1,12 @@
#include <Access/AllowedClientHosts.h>
#include <Common/Exception.h>
#include <common/SimpleCache.h>
#include <Common/StringUtils/StringUtils.h>
#include <IO/ReadHelpers.h>
#include <Functions/likePatternToRegexp.h>
#include <Poco/Net/SocketAddress.h>
#include <Poco/RegularExpression.h>
#include <common/logger_useful.h>
#include <ext/scope_guard.h>
#include <boost/range/algorithm/find.hpp>
#include <boost/range/algorithm/find_first_of.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/algorithm/string/replace.hpp>
#include <ifaddrs.h>
@ -27,20 +24,6 @@ namespace
using IPSubnet = AllowedClientHosts::IPSubnet;
const IPSubnet ALL_ADDRESSES{IPAddress{IPAddress::IPv6}, IPAddress{IPAddress::IPv6}};
const IPAddress & getIPV6Loopback()
{
static const IPAddress ip("::1");
return ip;
}
bool isIPV4LoopbackMappedToIPV6(const IPAddress & ip)
{
static const IPAddress prefix("::ffff:127.0.0.0");
/// 104 == 128 - 24, we have to reset the lowest 24 bits of 128 before comparing with `prefix`
/// (IPv4 loopback means any IP from 127.0.0.0 to 127.255.255.255).
return (ip & IPAddress(104, IPAddress::IPv6)) == prefix;
}
/// Converts an address to IPv6.
/// The loopback address "127.0.0.1" (or any "127.x.y.z") is converted to "::1".
IPAddress toIPv6(const IPAddress & ip)
@ -52,35 +35,18 @@ namespace
v6 = IPAddress("::ffff:" + ip.toString());
// ::ffff:127.XX.XX.XX -> ::1
if (isIPV4LoopbackMappedToIPV6(v6))
v6 = getIPV6Loopback();
if ((v6 & IPAddress(104, IPAddress::IPv6)) == IPAddress("::ffff:127.0.0.0"))
v6 = IPAddress{"::1"};
return v6;
}
/// Converts a subnet to IPv6.
IPSubnet toIPv6(const IPSubnet & subnet)
{
IPSubnet v6;
if (subnet.prefix.family() == IPAddress::IPv6)
v6.prefix = subnet.prefix;
else
v6.prefix = IPAddress("::ffff:" + subnet.prefix.toString());
if (subnet.mask.family() == IPAddress::IPv6)
v6.mask = subnet.mask;
else
v6.mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + subnet.mask.toString());
v6.prefix = v6.prefix & v6.mask;
// ::ffff:127.XX.XX.XX -> ::1
if (isIPV4LoopbackMappedToIPV6(v6.prefix))
v6 = {getIPV6Loopback(), IPAddress(128, IPAddress::IPv6)};
return v6;
return IPSubnet(toIPv6(subnet.getPrefix()), subnet.getMask());
}
/// Helper function for isAddressOfHost().
bool isAddressOfHostImpl(const IPAddress & address, const String & host)
{
@ -150,7 +116,7 @@ namespace
int err = getifaddrs(&ifa_begin);
if (err)
return {getIPV6Loopback()};
return {IPAddress{"::1"}};
for (const ifaddrs * ifa = ifa_begin; ifa; ifa = ifa->ifa_next)
{
@ -203,163 +169,203 @@ namespace
static SimpleCache<decltype(getHostByAddressImpl), &getHostByAddressImpl> cache;
return cache(address);
}
}
String AllowedClientHosts::IPSubnet::toString() const
{
unsigned int prefix_length = mask.prefixLength();
if (IPAddress{prefix_length, mask.family()} == mask)
return prefix.toString() + "/" + std::to_string(prefix_length);
return prefix.toString() + "/" + mask.toString();
}
AllowedClientHosts::AllowedClientHosts()
{
}
AllowedClientHosts::AllowedClientHosts(AllAddressesTag)
{
addAllAddresses();
}
AllowedClientHosts::~AllowedClientHosts() = default;
AllowedClientHosts::AllowedClientHosts(const AllowedClientHosts & src)
{
*this = src;
}
AllowedClientHosts & AllowedClientHosts::operator =(const AllowedClientHosts & src)
{
addresses = src.addresses;
localhost = src.localhost;
subnets = src.subnets;
host_names = src.host_names;
host_regexps = src.host_regexps;
compiled_host_regexps.clear();
return *this;
}
AllowedClientHosts::AllowedClientHosts(AllowedClientHosts && src) = default;
AllowedClientHosts & AllowedClientHosts::operator =(AllowedClientHosts && src) = default;
void AllowedClientHosts::clear()
{
addresses.clear();
localhost = false;
subnets.clear();
host_names.clear();
host_regexps.clear();
compiled_host_regexps.clear();
}
bool AllowedClientHosts::empty() const
{
return addresses.empty() && subnets.empty() && host_names.empty() && host_regexps.empty();
}
void AllowedClientHosts::addAddress(const IPAddress & address)
{
IPAddress addr_v6 = toIPv6(address);
if (boost::range::find(addresses, addr_v6) != addresses.end())
return;
addresses.push_back(addr_v6);
if (addr_v6.isLoopback())
localhost = true;
}
void AllowedClientHosts::addAddress(const String & address)
{
addAddress(IPAddress{address});
}
void AllowedClientHosts::addSubnet(const IPSubnet & subnet)
{
IPSubnet subnet_v6 = toIPv6(subnet);
if (subnet_v6.mask == IPAddress(128, IPAddress::IPv6))
void parseLikePatternIfIPSubnet(const String & pattern, IPSubnet & subnet, IPAddress::Family address_family)
{
addAddress(subnet_v6.prefix);
return;
size_t slash = pattern.find('/');
if (slash != String::npos)
{
/// IP subnet, e.g. "192.168.0.0/16" or "192.168.0.0/255.255.0.0".
subnet = IPSubnet{pattern};
return;
}
bool has_wildcard = (pattern.find_first_of("%_") != String::npos);
if (has_wildcard)
{
/// IP subnet specified with one of the wildcard characters, e.g. "192.168.%.%".
String wildcard_replaced_with_zero_bits = pattern;
String wildcard_replaced_with_one_bits = pattern;
if (address_family == IPAddress::IPv6)
{
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "_", "0");
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "%", "0000");
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "_", "f");
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "%", "ffff");
}
else if (address_family == IPAddress::IPv4)
{
boost::algorithm::replace_all(wildcard_replaced_with_zero_bits, "%", "0");
boost::algorithm::replace_all(wildcard_replaced_with_one_bits, "%", "255");
}
IPAddress prefix{wildcard_replaced_with_zero_bits};
IPAddress mask = ~(prefix ^ IPAddress{wildcard_replaced_with_one_bits});
subnet = IPSubnet{prefix, mask};
return;
}
/// Exact IP address.
subnet = IPSubnet{pattern};
}
if (boost::range::find(subnets, subnet_v6) == subnets.end())
subnets.push_back(subnet_v6);
}
void AllowedClientHosts::addSubnet(const IPAddress & prefix, const IPAddress & mask)
{
addSubnet(IPSubnet{prefix, mask});
}
void AllowedClientHosts::addSubnet(const IPAddress & prefix, size_t num_prefix_bits)
{
addSubnet(prefix, IPAddress(num_prefix_bits, prefix.family()));
}
void AllowedClientHosts::addSubnet(const String & subnet)
{
size_t slash = subnet.find('/');
if (slash == String::npos)
/// Extracts a subnet, a host name or a host name regular expession from a like pattern.
void parseLikePattern(
const String & pattern, std::optional<IPSubnet> & subnet, std::optional<String> & name, std::optional<String> & name_regexp)
{
addAddress(subnet);
return;
/// If `host` starts with digits and a dot then it's an IP pattern, otherwise it's a hostname pattern.
size_t first_not_digit = pattern.find_first_not_of("0123456789");
if ((first_not_digit != String::npos) && (first_not_digit != 0) && (pattern[first_not_digit] == '.'))
{
parseLikePatternIfIPSubnet(pattern, subnet.emplace(), IPAddress::IPv4);
return;
}
size_t first_not_hex = pattern.find_first_not_of("0123456789ABCDEFabcdef");
if (((first_not_hex == 4) && pattern[first_not_hex] == ':') || pattern.starts_with("::"))
{
parseLikePatternIfIPSubnet(pattern, subnet.emplace(), IPAddress::IPv6);
return;
}
bool has_wildcard = (pattern.find_first_of("%_") != String::npos);
if (has_wildcard)
{
name_regexp = likePatternToRegexp(pattern);
return;
}
name = pattern;
}
IPAddress prefix{String{subnet, 0, slash}};
String mask(subnet, slash + 1, subnet.length() - slash - 1);
if (std::all_of(mask.begin(), mask.end(), isNumericASCII))
addSubnet(prefix, parseFromString<UInt8>(mask));
else
addSubnet(prefix, IPAddress{mask});
}
void AllowedClientHosts::addHostName(const String & host_name)
bool AllowedClientHosts::contains(const IPAddress & client_address) const
{
if (boost::range::find(host_names, host_name) != host_names.end())
return;
host_names.push_back(host_name);
if (boost::iequals(host_name, "localhost"))
localhost = true;
}
if (any_host)
return true;
IPAddress client_v6 = toIPv6(client_address);
void AllowedClientHosts::addHostRegexp(const String & host_regexp)
{
if (boost::range::find(host_regexps, host_regexp) == host_regexps.end())
host_regexps.push_back(host_regexp);
}
std::optional<bool> is_client_local_value;
auto is_client_local = [&]
{
if (is_client_local_value)
return *is_client_local_value;
is_client_local_value = isAddressOfLocalhost(client_v6);
return *is_client_local_value;
};
if (local_host && is_client_local())
return true;
void AllowedClientHosts::addAllAddresses()
{
clear();
addSubnet(ALL_ADDRESSES);
}
/// Check `addresses`.
auto check_address = [&](const IPAddress & address_)
{
IPAddress address_v6 = toIPv6(address_);
if (address_v6.isLoopback())
return is_client_local();
return address_v6 == client_v6;
};
for (const auto & address : addresses)
if (check_address(address))
return true;
bool AllowedClientHosts::containsAllAddresses() const
{
return (boost::range::find(subnets, ALL_ADDRESSES) != subnets.end())
|| (boost::range::find(host_regexps, ".*") != host_regexps.end())
|| (boost::range::find(host_regexps, "$") != host_regexps.end());
/// Check `subnets`.
auto check_subnet = [&](const IPSubnet & subnet_)
{
IPSubnet subnet_v6 = toIPv6(subnet_);
if (subnet_v6.isMaskAllBitsOne())
return check_address(subnet_v6.getPrefix());
return (client_v6 & subnet_v6.getMask()) == subnet_v6.getPrefix();
};
for (const auto & subnet : subnets)
if (check_subnet(subnet))
return true;
/// Check `names`.
auto check_name = [&](const String & name_)
{
if (boost::iequals(name_, "localhost"))
return is_client_local();
try
{
return isAddressOfHost(client_v6, name_);
}
catch (const Exception & e)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << client_address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
return false;
}
};
for (const String & name : names)
if (check_name(name))
return true;
/// Check `name_regexps`.
std::optional<String> resolved_host;
auto check_name_regexp = [&](const String & name_regexp_)
{
try
{
if (boost::iequals(name_regexp_, "localhost"))
return is_client_local();
if (!resolved_host)
resolved_host = getHostByAddress(client_v6);
if (resolved_host->empty())
return false;
Poco::RegularExpression re(name_regexp_);
Poco::RegularExpression::Match match;
return re.match(*resolved_host, match) != 0;
}
catch (const Exception & e)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << client_address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
return false;
}
};
for (const String & name_regexp : name_regexps)
if (check_name_regexp(name_regexp))
return true;
auto check_like_pattern = [&](const String & pattern)
{
std::optional<IPSubnet> subnet;
std::optional<String> name;
std::optional<String> name_regexp;
parseLikePattern(pattern, subnet, name, name_regexp);
if (subnet)
return check_subnet(*subnet);
else if (name)
return check_name(*name);
else if (name_regexp)
return check_name_regexp(*name_regexp);
else
return false;
};
for (const String & like_pattern : like_patterns)
if (check_like_pattern(like_pattern))
return true;
return false;
}
@ -374,86 +380,4 @@ void AllowedClientHosts::checkContains(const IPAddress & address, const String &
}
}
bool AllowedClientHosts::contains(const IPAddress & address) const
{
/// Check `ip_addresses`.
IPAddress addr_v6 = toIPv6(address);
if (boost::range::find(addresses, addr_v6) != addresses.end())
return true;
if (localhost && isAddressOfLocalhost(addr_v6))
return true;
/// Check `ip_subnets`.
for (const auto & subnet : subnets)
if ((addr_v6 & subnet.mask) == subnet.prefix)
return true;
/// Check `hosts`.
for (const String & host_name : host_names)
{
try
{
if (isAddressOfHost(addr_v6, host_name))
return true;
}
catch (const Exception & e)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
}
}
/// Check `host_regexps`.
try
{
String resolved_host = getHostByAddress(addr_v6);
if (!resolved_host.empty())
{
compileRegexps();
for (const auto & compiled_regexp : compiled_host_regexps)
{
Poco::RegularExpression::Match match;
if (compiled_regexp && compiled_regexp->match(resolved_host, match))
return true;
}
}
}
catch (const Exception & e)
{
if (e.code() != ErrorCodes::DNS_ERROR)
throw;
/// Try to ignore DNS errors: if host cannot be resolved, skip it and try next.
LOG_WARNING(
&Logger::get("AddressPatterns"),
"Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText()
<< ", code = " << e.code());
}
return false;
}
void AllowedClientHosts::compileRegexps() const
{
if (compiled_host_regexps.size() == host_regexps.size())
return;
size_t old_size = compiled_host_regexps.size();
compiled_host_regexps.reserve(host_regexps.size());
for (size_t i = old_size; i != host_regexps.size(); ++i)
compiled_host_regexps.emplace_back(std::make_unique<Poco::RegularExpression>(host_regexps[i]));
}
bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs)
{
return (lhs.addresses == rhs.addresses) && (lhs.subnets == rhs.subnets) && (lhs.host_names == rhs.host_names)
&& (lhs.host_regexps == rhs.host_regexps);
}
}

View File

@ -4,12 +4,9 @@
#include <Poco/Net/IPAddress.h>
#include <memory>
#include <vector>
namespace Poco
{
class RegularExpression;
}
#include <boost/range/algorithm/find.hpp>
#include <boost/range/algorithm_ext/erase.hpp>
#include <boost/algorithm/string/predicate.hpp>
namespace DB
@ -20,69 +17,100 @@ class AllowedClientHosts
public:
using IPAddress = Poco::Net::IPAddress;
struct IPSubnet
class IPSubnet
{
IPAddress prefix;
IPAddress mask;
public:
IPSubnet() {}
IPSubnet(const IPAddress & prefix_, const IPAddress & mask_) { set(prefix_, mask_); }
IPSubnet(const IPAddress & prefix_, size_t num_prefix_bits) { set(prefix_, num_prefix_bits); }
explicit IPSubnet(const IPAddress & address) { set(address); }
explicit IPSubnet(const String & str);
const IPAddress & getPrefix() const { return prefix; }
const IPAddress & getMask() const { return mask; }
bool isMaskAllBitsOne() const;
String toString() const;
friend bool operator ==(const IPSubnet & lhs, const IPSubnet & rhs) { return (lhs.prefix == rhs.prefix) && (lhs.mask == rhs.mask); }
friend bool operator !=(const IPSubnet & lhs, const IPSubnet & rhs) { return !(lhs == rhs); }
private:
void set(const IPAddress & prefix_, const IPAddress & mask_);
void set(const IPAddress & prefix_, size_t num_prefix_bits);
void set(const IPAddress & address);
IPAddress prefix;
IPAddress mask;
};
struct AllAddressesTag {};
struct AnyHostTag {};
AllowedClientHosts();
explicit AllowedClientHosts(AllAddressesTag);
~AllowedClientHosts();
AllowedClientHosts() {}
explicit AllowedClientHosts(AnyHostTag) { addAnyHost(); }
~AllowedClientHosts() {}
AllowedClientHosts(const AllowedClientHosts & src);
AllowedClientHosts & operator =(const AllowedClientHosts & src);
AllowedClientHosts(AllowedClientHosts && src);
AllowedClientHosts & operator =(AllowedClientHosts && src);
AllowedClientHosts(const AllowedClientHosts & src) = default;
AllowedClientHosts & operator =(const AllowedClientHosts & src) = default;
AllowedClientHosts(AllowedClientHosts && src) = default;
AllowedClientHosts & operator =(AllowedClientHosts && src) = default;
/// Removes all contained addresses. This will disallow all addresses.
/// Removes all contained addresses. This will disallow all hosts.
void clear();
bool empty() const;
/// Allows exact IP address.
/// For example, 213.180.204.3 or 2a02:6b8::3
void addAddress(const IPAddress & address);
void addAddress(const String & address);
/// Allows an IP subnet.
void addSubnet(const IPSubnet & subnet);
void addSubnet(const String & subnet);
/// Allows an IP subnet.
/// For example, 312.234.1.1/255.255.255.0 or 2a02:6b8::3/FFFF:FFFF:FFFF:FFFF::
void addSubnet(const IPAddress & prefix, const IPAddress & mask);
/// Allows an IP subnet.
/// For example, 10.0.0.1/8 or 2a02:6b8::3/64
void addSubnet(const IPAddress & prefix, size_t num_prefix_bits);
/// Allows all addresses.
void addAllAddresses();
/// Allows an exact host. The `contains()` function will check that the provided address equals to one of that host's addresses.
void addHostName(const String & host_name);
/// Allows a regular expression for the host.
void addHostRegexp(const String & host_regexp);
void addAddress(const String & address) { addAddress(IPAddress(address)); }
void removeAddress(const IPAddress & address);
void removeAddress(const String & address) { removeAddress(IPAddress{address}); }
const std::vector<IPAddress> & getAddresses() const { return addresses; }
/// Allows an IP subnet.
/// For example, 312.234.1.1/255.255.255.0 or 2a02:6b8::3/64
void addSubnet(const IPSubnet & subnet);
void addSubnet(const String & subnet) { addSubnet(IPSubnet{subnet}); }
void addSubnet(const IPAddress & prefix, const IPAddress & mask) { addSubnet({prefix, mask}); }
void addSubnet(const IPAddress & prefix, size_t num_prefix_bits) { addSubnet({prefix, num_prefix_bits}); }
void removeSubnet(const IPSubnet & subnet);
void removeSubnet(const String & subnet) { removeSubnet(IPSubnet{subnet}); }
void removeSubnet(const IPAddress & prefix, const IPAddress & mask) { removeSubnet({prefix, mask}); }
void removeSubnet(const IPAddress & prefix, size_t num_prefix_bits) { removeSubnet({prefix, num_prefix_bits}); }
const std::vector<IPSubnet> & getSubnets() const { return subnets; }
const std::vector<String> & getHostNames() const { return host_names; }
const std::vector<String> & getHostRegexps() const { return host_regexps; }
/// Allows an exact host name. The `contains()` function will check that the provided address equals to one of that host's addresses.
void addName(const String & name);
void removeName(const String & name);
const std::vector<String> & getNames() const { return names; }
/// Allows the host names matching a regular expression.
void addNameRegexp(const String & name_regexp);
void removeNameRegexp(const String & name_regexp);
const std::vector<String> & getNameRegexps() const { return name_regexps; }
/// Allows IP addresses or host names using LIKE pattern.
/// This pattern can contain % and _ wildcard characters.
/// For example, addLikePattern("@") will allow all addresses.
void addLikePattern(const String & pattern);
void removeLikePattern(const String & like_pattern);
const std::vector<String> & getLikePatterns() const { return like_patterns; }
/// Allows local host.
void addLocalHost();
void removeLocalHost();
bool containsLocalHost() const { return local_host;}
/// Allows any host.
void addAnyHost();
bool containsAnyHost() const { return any_host;}
void add(const AllowedClientHosts & other);
void remove(const AllowedClientHosts & other);
/// Checks if the provided address is in the list. Returns false if not.
bool contains(const IPAddress & address) const;
/// Checks if any address is allowed.
bool containsAllAddresses() const;
/// Checks if the provided address is in the list. Throws an exception if not.
/// `username` is only used for generating an error message if the address isn't in the list.
void checkContains(const IPAddress & address, const String & user_name = String()) const;
@ -91,13 +119,269 @@ public:
friend bool operator !=(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs) { return !(lhs == rhs); }
private:
void compileRegexps() const;
std::vector<IPAddress> addresses;
bool localhost = false;
std::vector<IPSubnet> subnets;
std::vector<String> host_names;
std::vector<String> host_regexps;
mutable std::vector<std::unique_ptr<Poco::RegularExpression>> compiled_host_regexps;
Strings names;
Strings name_regexps;
Strings like_patterns;
bool any_host = false;
bool local_host = false;
};
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & prefix_, const IPAddress & mask_)
{
prefix = prefix_;
mask = mask_;
if (prefix.family() != mask.family())
{
if (prefix.family() == IPAddress::IPv4)
prefix = IPAddress("::ffff:" + prefix.toString());
if (mask.family() == IPAddress::IPv4)
mask = IPAddress(96, IPAddress::IPv6) | IPAddress("::ffff:" + mask.toString());
}
prefix = prefix & mask;
if (prefix.family() == IPAddress::IPv4)
{
if ((prefix & IPAddress{8, IPAddress::IPv4}) == IPAddress{"127.0.0.0"})
{
// 127.XX.XX.XX -> 127.0.0.1
prefix = IPAddress{"127.0.0.1"};
mask = IPAddress{32, IPAddress::IPv4};
}
}
else
{
if ((prefix & IPAddress{104, IPAddress::IPv6}) == IPAddress{"::ffff:127.0.0.0"})
{
// ::ffff:127.XX.XX.XX -> ::1
prefix = IPAddress{"::1"};
mask = IPAddress{128, IPAddress::IPv6};
}
}
}
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & prefix_, size_t num_prefix_bits)
{
set(prefix_, IPAddress(num_prefix_bits, prefix_.family()));
}
inline void AllowedClientHosts::IPSubnet::set(const IPAddress & address)
{
set(address, address.length() * 8);
}
inline AllowedClientHosts::IPSubnet::IPSubnet(const String & str)
{
size_t slash = str.find('/');
if (slash == String::npos)
{
set(IPAddress(str));
return;
}
IPAddress new_prefix{String{str, 0, slash}};
String mask_str(str, slash + 1, str.length() - slash - 1);
bool only_digits = (mask_str.find_first_not_of("0123456789") == std::string::npos);
if (only_digits)
set(new_prefix, std::stoul(mask_str));
else
set(new_prefix, IPAddress{mask_str});
}
inline String AllowedClientHosts::IPSubnet::toString() const
{
unsigned int prefix_length = mask.prefixLength();
if (isMaskAllBitsOne())
return prefix.toString();
else if (IPAddress{prefix_length, mask.family()} == mask)
return prefix.toString() + "/" + std::to_string(prefix_length);
else
return prefix.toString() + "/" + mask.toString();
}
inline bool AllowedClientHosts::IPSubnet::isMaskAllBitsOne() const
{
return mask == IPAddress(mask.length() * 8, mask.family());
}
inline void AllowedClientHosts::clear()
{
addresses = {};
subnets = {};
names = {};
name_regexps = {};
like_patterns = {};
any_host = false;
local_host = false;
}
inline bool AllowedClientHosts::empty() const
{
return !any_host && !local_host && addresses.empty() && subnets.empty() && names.empty() && name_regexps.empty() && like_patterns.empty();
}
inline void AllowedClientHosts::addAddress(const IPAddress & address)
{
if (address.isLoopback())
local_host = true;
else if (boost::range::find(addresses, address) == addresses.end())
addresses.push_back(address);
}
inline void AllowedClientHosts::removeAddress(const IPAddress & address)
{
if (address.isLoopback())
local_host = false;
else
boost::range::remove_erase(addresses, address);
}
inline void AllowedClientHosts::addSubnet(const IPSubnet & subnet)
{
if (subnet.getMask().isWildcard())
any_host = true;
else if (subnet.isMaskAllBitsOne())
addAddress(subnet.getPrefix());
else if (boost::range::find(subnets, subnet) == subnets.end())
subnets.push_back(subnet);
}
inline void AllowedClientHosts::removeSubnet(const IPSubnet & subnet)
{
if (subnet.getMask().isWildcard())
any_host = false;
else if (subnet.isMaskAllBitsOne())
removeAddress(subnet.getPrefix());
else
boost::range::remove_erase(subnets, subnet);
}
inline void AllowedClientHosts::addName(const String & name)
{
if (boost::iequals(name, "localhost"))
local_host = true;
else if (boost::range::find(names, name) == names.end())
names.push_back(name);
}
inline void AllowedClientHosts::removeName(const String & name)
{
if (boost::iequals(name, "localhost"))
local_host = false;
else
boost::range::remove_erase(names, name);
}
inline void AllowedClientHosts::addNameRegexp(const String & name_regexp)
{
if (boost::iequals(name_regexp, "localhost"))
local_host = true;
else if (name_regexp == ".*")
any_host = true;
else if (boost::range::find(name_regexps, name_regexp) == name_regexps.end())
name_regexps.push_back(name_regexp);
}
inline void AllowedClientHosts::removeNameRegexp(const String & name_regexp)
{
if (boost::iequals(name_regexp, "localhost"))
local_host = false;
else if (name_regexp == ".*")
any_host = false;
else
boost::range::remove_erase(name_regexps, name_regexp);
}
inline void AllowedClientHosts::addLikePattern(const String & pattern)
{
if (boost::iequals(pattern, "localhost") || (pattern == "127.0.0.1") || (pattern == "::1"))
local_host = true;
else if ((pattern == "@") || (pattern == "0.0.0.0/0") || (pattern == "::/0"))
any_host = true;
else if (boost::range::find(like_patterns, pattern) == name_regexps.end())
like_patterns.push_back(pattern);
}
inline void AllowedClientHosts::removeLikePattern(const String & pattern)
{
if (boost::iequals(pattern, "localhost") || (pattern == "127.0.0.1") || (pattern == "::1"))
local_host = false;
else if ((pattern == "@") || (pattern == "0.0.0.0/0") || (pattern == "::/0"))
any_host = false;
else
boost::range::remove_erase(like_patterns, pattern);
}
inline void AllowedClientHosts::addLocalHost()
{
local_host = true;
}
inline void AllowedClientHosts::removeLocalHost()
{
local_host = false;
}
inline void AllowedClientHosts::addAnyHost()
{
clear();
any_host = true;
}
inline void AllowedClientHosts::add(const AllowedClientHosts & other)
{
if (other.containsAnyHost())
{
addAnyHost();
return;
}
if (other.containsLocalHost())
addLocalHost();
for (const IPAddress & address : other.getAddresses())
addAddress(address);
for (const IPSubnet & subnet : other.getSubnets())
addSubnet(subnet);
for (const String & name : other.getNames())
addName(name);
for (const String & name_regexp : other.getNameRegexps())
addNameRegexp(name_regexp);
for (const String & like_pattern : other.getLikePatterns())
addLikePattern(like_pattern);
}
inline void AllowedClientHosts::remove(const AllowedClientHosts & other)
{
if (other.containsAnyHost())
{
clear();
return;
}
if (other.containsLocalHost())
removeLocalHost();
for (const IPAddress & address : other.getAddresses())
removeAddress(address);
for (const IPSubnet & subnet : other.getSubnets())
removeSubnet(subnet);
for (const String & name : other.getNames())
removeName(name);
for (const String & name_regexp : other.getNameRegexps())
removeNameRegexp(name_regexp);
for (const String & like_pattern : other.getLikePatterns())
removeLikePattern(like_pattern);
}
inline bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs)
{
return (lhs.any_host == rhs.any_host) && (lhs.local_host == rhs.local_host) && (lhs.addresses == rhs.addresses)
&& (lhs.subnets == rhs.subnets) && (lhs.names == rhs.names) && (lhs.name_regexps == rhs.name_regexps)
&& (lhs.like_patterns == rhs.like_patterns);
}
}

View File

@ -1,166 +1,18 @@
#include <Access/Authentication.h>
#include <Common/Exception.h>
#include <common/StringRef.h>
#include <Core/Defines.h>
#include <Poco/SHA1Engine.h>
#include <boost/algorithm/hex.hpp>
#include "config_core.h"
#if USE_SSL
# include <openssl/sha.h>
#endif
namespace DB
{
namespace ErrorCodes
{
extern const int SUPPORT_IS_DISABLED;
extern const int REQUIRED_PASSWORD;
extern const int WRONG_PASSWORD;
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
}
namespace
{
using Digest = Authentication::Digest;
Digest encodePlainText(const StringRef & text)
{
return Digest(text.data, text.data + text.size);
}
Digest encodeSHA256(const StringRef & text)
{
#if USE_SSL
Digest hash;
hash.resize(32);
SHA256_CTX ctx;
SHA256_Init(&ctx);
SHA256_Update(&ctx, reinterpret_cast<const UInt8 *>(text.data), text.size);
SHA256_Final(hash.data(), &ctx);
return hash;
#else
UNUSED(text);
throw DB::Exception("SHA256 passwords support is disabled, because ClickHouse was built without SSL library", DB::ErrorCodes::SUPPORT_IS_DISABLED);
#endif
}
Digest encodeSHA1(const StringRef & text)
{
Poco::SHA1Engine engine;
engine.update(text.data, text.size);
return engine.digest();
}
Digest encodeSHA1(const Digest & text)
{
return encodeSHA1(StringRef{reinterpret_cast<const char *>(text.data()), text.size()});
}
Digest encodeDoubleSHA1(const StringRef & text)
{
return encodeSHA1(encodeSHA1(text));
}
}
Authentication::Authentication(Authentication::Type type_)
: type(type_)
{
}
void Authentication::setPassword(const String & password_)
{
switch (type)
{
case NO_PASSWORD:
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
case PLAINTEXT_PASSWORD:
setPasswordHashBinary(encodePlainText(password_));
return;
case SHA256_PASSWORD:
setPasswordHashBinary(encodeSHA256(password_));
return;
case DOUBLE_SHA1_PASSWORD:
setPasswordHashBinary(encodeDoubleSHA1(password_));
return;
}
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
}
String Authentication::getPassword() const
{
if (type != PLAINTEXT_PASSWORD)
throw Exception("Cannot decode the password", ErrorCodes::LOGICAL_ERROR);
return String(password_hash.data(), password_hash.data() + password_hash.size());
}
void Authentication::setPasswordHashHex(const String & hash)
{
Digest digest;
digest.resize(hash.size() / 2);
boost::algorithm::unhex(hash.begin(), hash.end(), digest.data());
setPasswordHashBinary(digest);
}
String Authentication::getPasswordHashHex() const
{
String hex;
hex.resize(password_hash.size() * 2);
boost::algorithm::hex(password_hash.begin(), password_hash.end(), hex.data());
return hex;
}
void Authentication::setPasswordHashBinary(const Digest & hash)
{
switch (type)
{
case NO_PASSWORD:
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
case PLAINTEXT_PASSWORD:
{
password_hash = hash;
return;
}
case SHA256_PASSWORD:
{
if (hash.size() != 32)
throw Exception(
"Password hash for the 'SHA256_PASSWORD' authentication type has length " + std::to_string(hash.size())
+ " but must be exactly 32 bytes.",
ErrorCodes::BAD_ARGUMENTS);
password_hash = hash;
return;
}
case DOUBLE_SHA1_PASSWORD:
{
if (hash.size() != 20)
throw Exception(
"Password hash for the 'DOUBLE_SHA1_PASSWORD' authentication type has length " + std::to_string(hash.size())
+ " but must be exactly 20 bytes.",
ErrorCodes::BAD_ARGUMENTS);
password_hash = hash;
return;
}
}
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
}
Digest Authentication::getPasswordDoubleSHA1() const
Authentication::Digest Authentication::getPasswordDoubleSHA1() const
{
switch (type)
{
@ -198,12 +50,12 @@ bool Authentication::isCorrectPassword(const String & password_) const
case PLAINTEXT_PASSWORD:
{
if (password_ == StringRef{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()})
if (password_ == std::string_view{reinterpret_cast<const char *>(password_hash.data()), password_hash.size()})
return true;
// For compatibility with MySQL clients which support only native authentication plugin, SHA1 can be passed instead of password.
auto password_sha1 = encodeSHA1(password_hash);
return password_ == StringRef{reinterpret_cast<const char *>(password_sha1.data()), password_sha1.size()};
return password_ == std::string_view{reinterpret_cast<const char *>(password_sha1.data()), password_sha1.size()};
}
case SHA256_PASSWORD:
@ -234,10 +86,5 @@ void Authentication::checkPassword(const String & password_, const String & user
throw Exception("Wrong password" + info_about_user_name(), ErrorCodes::WRONG_PASSWORD);
}
bool operator ==(const Authentication & lhs, const Authentication & rhs)
{
return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash);
}
}

View File

@ -1,10 +1,22 @@
#pragma once
#include <Core/Types.h>
#include <Common/Exception.h>
#include <Common/OpenSSLHelpers.h>
#include <Poco/SHA1Engine.h>
#include <boost/algorithm/hex.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int SUPPORT_IS_DISABLED;
extern const int BAD_ARGUMENTS;
extern const int LOGICAL_ERROR;
}
/// Authentication type and encrypted password for checking when an user logins.
class Authentication
{
@ -27,7 +39,7 @@ public:
using Digest = std::vector<UInt8>;
Authentication(Authentication::Type type = NO_PASSWORD);
Authentication(Authentication::Type type_ = NO_PASSWORD) : type(type_) {}
Authentication(const Authentication & src) = default;
Authentication & operator =(const Authentication & src) = default;
Authentication(Authentication && src) = default;
@ -36,17 +48,19 @@ public:
Type getType() const { return type; }
/// Sets the password and encrypt it using the authentication type set in the constructor.
void setPassword(const String & password);
void setPassword(const String & password_);
/// Returns the password. Allowed to use only for Type::PLAINTEXT_PASSWORD.
String getPassword() const;
/// Sets the password as a string of hexadecimal digits.
void setPasswordHashHex(const String & hash);
String getPasswordHashHex() const;
/// Sets the password in binary form.
void setPasswordHashBinary(const Digest & hash);
const Digest & getPasswordHashBinary() const { return password_hash; }
/// Returns SHA1(SHA1(password)) used by MySQL compatibility server for authentication.
@ -60,11 +74,124 @@ public:
/// `user_name` is only used for generating an error message if the password is incorrect.
void checkPassword(const String & password, const String & user_name = String()) const;
friend bool operator ==(const Authentication & lhs, const Authentication & rhs);
friend bool operator ==(const Authentication & lhs, const Authentication & rhs) { return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash); }
friend bool operator !=(const Authentication & lhs, const Authentication & rhs) { return !(lhs == rhs); }
private:
static Digest encodePlainText(const std::string_view & text) { return Digest(text.data(), text.data() + text.size()); }
static Digest encodeSHA256(const std::string_view & text);
static Digest encodeSHA1(const std::string_view & text);
static Digest encodeSHA1(const Digest & text) { return encodeSHA1(std::string_view{reinterpret_cast<const char *>(text.data()), text.size()}); }
static Digest encodeDoubleSHA1(const std::string_view & text) { return encodeSHA1(encodeSHA1(text)); }
Type type = Type::NO_PASSWORD;
Digest password_hash;
};
inline Authentication::Digest Authentication::encodeSHA256(const std::string_view & text [[maybe_unused]])
{
#if USE_SSL
Digest hash;
hash.resize(32);
::DB::encodeSHA256(text, hash.data());
return hash;
#else
throw DB::Exception(
"SHA256 passwords support is disabled, because ClickHouse was built without SSL library",
DB::ErrorCodes::SUPPORT_IS_DISABLED);
#endif
}
inline Authentication::Digest Authentication::encodeSHA1(const std::string_view & text)
{
Poco::SHA1Engine engine;
engine.update(text.data(), text.size());
return engine.digest();
}
inline void Authentication::setPassword(const String & password_)
{
switch (type)
{
case NO_PASSWORD:
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
case PLAINTEXT_PASSWORD:
return setPasswordHashBinary(encodePlainText(password_));
case SHA256_PASSWORD:
return setPasswordHashBinary(encodeSHA256(password_));
case DOUBLE_SHA1_PASSWORD:
return setPasswordHashBinary(encodeDoubleSHA1(password_));
}
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
}
inline String Authentication::getPassword() const
{
if (type != PLAINTEXT_PASSWORD)
throw Exception("Cannot decode the password", ErrorCodes::LOGICAL_ERROR);
return String(password_hash.data(), password_hash.data() + password_hash.size());
}
inline void Authentication::setPasswordHashHex(const String & hash)
{
Digest digest;
digest.resize(hash.size() / 2);
boost::algorithm::unhex(hash.begin(), hash.end(), digest.data());
setPasswordHashBinary(digest);
}
inline String Authentication::getPasswordHashHex() const
{
String hex;
hex.resize(password_hash.size() * 2);
boost::algorithm::hex(password_hash.begin(), password_hash.end(), hex.data());
return hex;
}
inline void Authentication::setPasswordHashBinary(const Digest & hash)
{
switch (type)
{
case NO_PASSWORD:
throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR);
case PLAINTEXT_PASSWORD:
{
password_hash = hash;
return;
}
case SHA256_PASSWORD:
{
if (hash.size() != 32)
throw Exception(
"Password hash for the 'SHA256_PASSWORD' authentication type has length " + std::to_string(hash.size())
+ " but must be exactly 32 bytes.",
ErrorCodes::BAD_ARGUMENTS);
password_hash = hash;
return;
}
case DOUBLE_SHA1_PASSWORD:
{
if (hash.size() != 20)
throw Exception(
"Password hash for the 'DOUBLE_SHA1_PASSWORD' authentication type has length " + std::to_string(hash.size())
+ " but must be exactly 20 bytes.",
ErrorCodes::BAD_ARGUMENTS);
password_hash = hash;
return;
}
}
throw Exception("Unknown authentication type: " + std::to_string(static_cast<int>(type)), ErrorCodes::LOGICAL_ERROR);
}
}

View File

@ -10,7 +10,8 @@ bool User::equal(const IAccessEntity & other) const
return false;
const auto & other_user = typeid_cast<const User &>(other);
return (authentication == other_user.authentication) && (allowed_client_hosts == other_user.allowed_client_hosts)
&& (access == other_user.access) && (profile == other_user.profile);
&& (access == other_user.access) && (access_with_grant_option == other_user.access_with_grant_option)
&& (profile == other_user.profile);
}
}

View File

@ -14,8 +14,9 @@ namespace DB
struct User : public IAccessEntity
{
Authentication authentication;
AllowedClientHosts allowed_client_hosts;
AllowedClientHosts allowed_client_hosts{AllowedClientHosts::AnyHostTag{}};
AccessRights access;
AccessRights access_with_grant_option;
String profile;
bool equal(const IAccessEntity & other) const override;

View File

@ -90,15 +90,16 @@ namespace
{
Poco::Util::AbstractConfiguration::Keys keys;
config.keys(networks_config, keys);
user->allowed_client_hosts.clear();
for (const String & key : keys)
{
String value = config.getString(networks_config + "." + key);
if (key.starts_with("ip"))
user->allowed_client_hosts.addSubnet(value);
else if (key.starts_with("host_regexp"))
user->allowed_client_hosts.addHostRegexp(value);
user->allowed_client_hosts.addNameRegexp(value);
else if (key.starts_with("host"))
user->allowed_client_hosts.addHostName(value);
user->allowed_client_hosts.addName(value);
else
throw Exception("Unknown address pattern type: " + key, ErrorCodes::UNKNOWN_ADDRESS_PATTERN_TYPE);
}
@ -143,7 +144,6 @@ namespace
user->access.fullRevoke(AccessFlags::databaseLevel());
for (const String & database : *databases)
user->access.grant(AccessFlags::databaseLevel(), database);
user->access.grant(AccessFlags::databaseLevel(), "system"); /// Anyone has access to the "system" database.
}
if (dictionaries)
@ -155,6 +155,8 @@ namespace
else if (databases)
user->access.grant(AccessType::dictGet, IDictionary::NO_DATABASE_TAG);
user->access_with_grant_option = user->access;
return user;
}

View File

@ -481,6 +481,7 @@ namespace ErrorCodes
extern const int UNABLE_TO_SKIP_UNUSED_SHARDS = 507;
extern const int UNKNOWN_ACCESS_TYPE = 508;
extern const int INVALID_GRANT = 509;
extern const int CACHE_DICTIONARY_UPDATE_FAIL = 510;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@ -3,11 +3,20 @@
#include "OpenSSLHelpers.h"
#include <ext/scope_guard.h>
#include <openssl/err.h>
#include <openssl/sha.h>
namespace DB
{
#pragma GCC diagnostic warning "-Wold-style-cast"
void encodeSHA256(const std::string_view & text, unsigned char * out)
{
SHA256_CTX ctx;
SHA256_Init(&ctx);
SHA256_Update(&ctx, reinterpret_cast<const UInt8 *>(text.data()), text.size());
SHA256_Final(out, &ctx);
}
String getOpenSSLErrors()
{
BIO * mem = BIO_new(BIO_s_mem());

View File

@ -7,6 +7,8 @@
namespace DB
{
/// Encodes `text` and puts the result to `out` which must be at least 32 bytes long.
void encodeSHA256(const std::string_view & text, unsigned char * out);
/// Returns concatenation of error strings for all errors that OpenSSL has recorded, emptying the error queue.
String getOpenSSLErrors();

View File

@ -23,6 +23,78 @@ extern const int LOGICAL_ERROR;
namespace
{
/// Fixed TypeIds that numbers would not be changed between versions.
enum class MagicNumber : uint8_t
{
UInt8 = 1,
UInt16 = 2,
UInt32 = 3,
UInt64 = 4,
Int8 = 6,
Int16 = 7,
Int32 = 8,
Int64 = 9,
Date = 13,
DateTime = 14,
DateTime64 = 15,
Enum8 = 17,
Enum16 = 18,
Decimal32 = 19,
Decimal64 = 20,
};
MagicNumber serializeTypeId(TypeIndex type_id)
{
switch (type_id)
{
case TypeIndex::UInt8: return MagicNumber::UInt8;
case TypeIndex::UInt16: return MagicNumber::UInt16;
case TypeIndex::UInt32: return MagicNumber::UInt32;
case TypeIndex::UInt64: return MagicNumber::UInt64;
case TypeIndex::Int8: return MagicNumber::Int8;
case TypeIndex::Int16: return MagicNumber::Int16;
case TypeIndex::Int32: return MagicNumber::Int32;
case TypeIndex::Int64: return MagicNumber::Int64;
case TypeIndex::Date: return MagicNumber::Date;
case TypeIndex::DateTime: return MagicNumber::DateTime;
case TypeIndex::DateTime64: return MagicNumber::DateTime64;
case TypeIndex::Enum8: return MagicNumber::Enum8;
case TypeIndex::Enum16: return MagicNumber::Enum16;
case TypeIndex::Decimal32: return MagicNumber::Decimal32;
case TypeIndex::Decimal64: return MagicNumber::Decimal64;
default:
break;
}
throw Exception("Type is not supported by T64 codec: " + toString(UInt32(type_id)), ErrorCodes::LOGICAL_ERROR);
}
TypeIndex deserializeTypeId(uint8_t serialized_type_id)
{
MagicNumber magic = static_cast<MagicNumber>(serialized_type_id);
switch (magic)
{
case MagicNumber::UInt8: return TypeIndex::UInt8;
case MagicNumber::UInt16: return TypeIndex::UInt16;
case MagicNumber::UInt32: return TypeIndex::UInt32;
case MagicNumber::UInt64: return TypeIndex::UInt64;
case MagicNumber::Int8: return TypeIndex::Int8;
case MagicNumber::Int16: return TypeIndex::Int16;
case MagicNumber::Int32: return TypeIndex::Int32;
case MagicNumber::Int64: return TypeIndex::Int64;
case MagicNumber::Date: return TypeIndex::Date;
case MagicNumber::DateTime: return TypeIndex::DateTime;
case MagicNumber::DateTime64: return TypeIndex::DateTime64;
case MagicNumber::Enum8: return TypeIndex::Enum8;
case MagicNumber::Enum16: return TypeIndex::Enum16;
case MagicNumber::Decimal32: return TypeIndex::Decimal32;
case MagicNumber::Decimal64: return TypeIndex::Decimal64;
}
throw Exception("Bad magic number in T64 codec: " + toString(UInt32(serialized_type_id)), ErrorCodes::LOGICAL_ERROR);
}
UInt8 codecId()
{
return static_cast<UInt8>(CompressionMethodByte::T64);
@ -41,6 +113,7 @@ TypeIndex baseType(TypeIndex type_idx)
return TypeIndex::Int32;
case TypeIndex::Int64:
case TypeIndex::Decimal64:
case TypeIndex::DateTime64:
return TypeIndex::Int64;
case TypeIndex::UInt8:
case TypeIndex::Enum8:
@ -79,6 +152,7 @@ TypeIndex typeIdx(const DataTypePtr & data_type)
case TypeIndex::Int32:
case TypeIndex::UInt32:
case TypeIndex::DateTime:
case TypeIndex::DateTime64:
case TypeIndex::Decimal32:
case TypeIndex::Int64:
case TypeIndex::UInt64:
@ -490,7 +564,7 @@ void decompressData(const char * src, UInt32 src_size, char * dst, UInt32 uncomp
UInt32 CompressionCodecT64::doCompressData(const char * src, UInt32 src_size, char * dst) const
{
UInt8 cookie = static_cast<UInt8>(type_idx) | (static_cast<UInt8>(variant) << 7);
UInt8 cookie = static_cast<UInt8>(serializeTypeId(type_idx)) | (static_cast<UInt8>(variant) << 7);
memcpy(dst, &cookie, 1);
dst += 1;
@ -529,7 +603,7 @@ void CompressionCodecT64::doDecompressData(const char * src, UInt32 src_size, ch
src_size -= 1;
auto saved_variant = static_cast<Variant>(cookie >> 7);
auto saved_type_id = static_cast<TypeIndex>(cookie & 0x7F);
TypeIndex saved_type_id = deserializeTypeId(cookie & 0x7F);
switch (baseType(saved_type_id))
{

View File

@ -7,6 +7,7 @@
#include <Common/PODArray.h>
#include <Core/Types.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/User.h>
#include <IO/copyData.h>
#include <IO/LimitReadBuffer.h>
@ -952,7 +953,7 @@ public:
throw Exception("Wrong size of auth response. Expected: " + std::to_string(Poco::SHA1Engine::DIGEST_SIZE) + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.",
ErrorCodes::UNKNOWN_EXCEPTION);
auto user = context.getUser(user_name);
auto user = context.getAccessControlManager().getUser(user_name);
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1();
assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);

View File

@ -391,6 +391,9 @@ struct Settings : public SettingsCollection<Settings>
M(SettingUInt64, mutations_sync, 0, "Wait for synchronous execution of ALTER TABLE UPDATE/DELETE queries (mutations). 0 - execute asynchronously. 1 - wait current server. 2 - wait all replicas if they exist.", 0) \
M(SettingBool, optimize_if_chain_to_miltiif, false, "Replace if(cond1, then1, if(cond2, ...)) chains to multiIf. Currently it's not beneficial for numeric types.", 0) \
M(SettingBool, allow_experimental_alter_materialized_view_structure, false, "Allow atomic alter on Materialized views. Work in progress.", 0) \
M(SettingBool, enable_early_constant_folding, true, "Enable query optimization where we analyze function and subqueries results and rewrite query if there're constants there", 0) \
\
M(SettingBool, partial_revokes, false, "Makes it possible to revoke privileges partially.", 0) \
\
/** Obsolete settings that do nothing but left for compatibility reasons. Remove each one after half a year of obsolescence. */ \
\

View File

@ -14,6 +14,7 @@ namespace DB
struct Null {};
/// @note Except explicitly described you should not assume on TypeIndex numbers and/or their orders in this enum.
enum class TypeIndex
{
Nothing = 0,

View File

@ -621,6 +621,12 @@ inline bool isStringOrFixedString(const T & data_type)
return WhichDataType(data_type).isStringOrFixedString();
}
template <typename T>
inline bool isNotCreatable(const T & data_type)
{
WhichDataType which(data_type);
return which.isNothing() || which.isFunction() || which.isSet();
}
inline bool isNotDecimalButComparableToDecimal(const DataTypePtr & data_type)
{

View File

@ -12,6 +12,7 @@
#include <Common/typeid_cast.h>
#include <ext/range.h>
#include <ext/size.h>
#include <Common/setThreadName.h>
#include "CacheDictionary.inc.h"
#include "DictionaryBlockInputStream.h"
#include "DictionaryFactory.h"
@ -61,24 +62,48 @@ CacheDictionary::CacheDictionary(
const std::string & name_,
const DictionaryStructure & dict_struct_,
DictionarySourcePtr source_ptr_,
const DictionaryLifetime dict_lifetime_,
const size_t size_)
DictionaryLifetime dict_lifetime_,
size_t size_,
bool allow_read_expired_keys_,
size_t max_update_queue_size_,
size_t update_queue_push_timeout_milliseconds_,
size_t max_threads_for_updates_)
: database(database_)
, name(name_)
, full_name{database_.empty() ? name_ : (database_ + "." + name_)}
, dict_struct(dict_struct_)
, source_ptr{std::move(source_ptr_)}
, dict_lifetime(dict_lifetime_)
, allow_read_expired_keys(allow_read_expired_keys_)
, max_update_queue_size(max_update_queue_size_)
, update_queue_push_timeout_milliseconds(update_queue_push_timeout_milliseconds_)
, max_threads_for_updates(max_threads_for_updates_)
, log(&Logger::get("ExternalDictionaries"))
, size{roundUpToPowerOfTwoOrZero(std::max(size_, size_t(max_collision_length)))}
, size_overlap_mask{this->size - 1}
, cells{this->size}
, rnd_engine(randomSeed())
, update_queue(max_update_queue_size_)
, update_pool(max_threads_for_updates)
{
if (!this->source_ptr->supportsSelectiveLoad())
throw Exception{full_name + ": source cannot be used with CacheDictionary", ErrorCodes::UNSUPPORTED_METHOD};
createAttributes();
for (size_t i = 0; i < max_threads_for_updates; ++i)
update_pool.scheduleOrThrowOnError([this] { updateThreadFunction(); });
}
CacheDictionary::~CacheDictionary()
{
finished = true;
update_queue.clear();
for (size_t i = 0; i < max_threads_for_updates; ++i)
{
auto empty_finishing_ptr = std::make_shared<UpdateUnit>(std::vector<Key>());
update_queue.push(empty_finishing_ptr);
}
update_pool.wait();
}
@ -275,10 +300,16 @@ CacheDictionary::FindResult CacheDictionary::findCellIdx(const Key & id, const C
void CacheDictionary::has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8> & out) const
{
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
std::unordered_map<Key, std::vector<size_t>> outdated_ids;
/// There are three types of ids.
/// - Valid ids. These ids are presented in local cache and their lifetime is not expired.
/// - CacheExpired ids. Ids that are in local cache, but their values are rotted (lifetime is expired).
/// - CacheNotFound ids. We have to go to external storage to know its value.
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
std::unordered_map<Key, std::vector<size_t>> cache_expired_ids;
std::unordered_map<Key, std::vector<size_t>> cache_not_found_ids;
size_t cache_hit = 0;
const auto rows = ext::size(ids);
{
@ -291,49 +322,97 @@ void CacheDictionary::has(const PaddedPODArray<Key> & ids, PaddedPODArray<UInt8>
const auto id = ids[row];
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto insert_to_answer_routine = [&] ()
{
out[row] = !cells[cell_idx].isDefault();
};
if (!find_result.valid)
{
outdated_ids[id].push_back(row);
if (find_result.outdated)
++cache_expired;
{
cache_expired_ids[id].push_back(row);
if (allow_read_expired_keys)
insert_to_answer_routine();
}
else
++cache_not_found;
{
cache_not_found_ids[id].push_back(row);
}
}
else
{
++cache_hit;
const auto & cell = cells[cell_idx];
out[row] = !cell.isDefault();
insert_to_answer_routine();
}
}
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
hit_count.fetch_add(rows - cache_expired_ids.size() - cache_not_found_ids.size(), std::memory_order_release);
if (outdated_ids.empty())
return;
if (cache_not_found_ids.empty())
{
/// Nothing to update - return;
if (cache_expired_ids.empty())
return;
std::vector<Key> required_ids(outdated_ids.size());
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), [](auto & pair) { return pair.first; });
/// request new values
update(
required_ids,
[&](const auto id, const auto)
if (allow_read_expired_keys)
{
for (const auto row : outdated_ids[id])
out[row] = true;
},
[&](const auto id, const auto)
{
for (const auto row : outdated_ids[id])
out[row] = false;
});
std::vector<Key> required_expired_ids;
required_expired_ids.reserve(cache_expired_ids.size());
std::transform(
std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_expired_ids), [](auto & pair) { return pair.first; });
/// Callbacks are empty because we don't want to receive them after an unknown period of time.
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_expired_ids);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
/// Update is async - no need to wait.
return;
}
}
/// At this point we have two situations.
/// There may be both types of keys: cache_expired_ids and cache_not_found_ids.
/// We will update them all synchronously.
std::vector<Key> required_ids;
required_ids.reserve(cache_not_found_ids.size() + cache_expired_ids.size());
std::transform(
std::begin(cache_not_found_ids), std::end(cache_not_found_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
std::transform(
std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
auto on_cell_updated = [&] (const Key id, const size_t)
{
for (const auto row : cache_not_found_ids[id])
out[row] = true;
for (const auto row : cache_expired_ids[id])
out[row] = true;
};
auto on_id_not_found = [&] (const Key id, const size_t)
{
for (const auto row : cache_not_found_ids[id])
out[row] = false;
for (const auto row : cache_expired_ids[id])
out[row] = true;
};
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_ids, on_cell_updated, on_id_not_found);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
waitForCurrentUpdateFinish(update_unit_ptr);
}
@ -590,7 +669,8 @@ void registerDictionaryCache(DictionaryFactory & factory)
DictionarySourcePtr source_ptr) -> DictionaryPtr
{
if (dict_struct.key)
throw Exception{"'key' is not supported for dictionary of layout 'cache'", ErrorCodes::UNSUPPORTED_METHOD};
throw Exception{"'key' is not supported for dictionary of layout 'cache'",
ErrorCodes::UNSUPPORTED_METHOD};
if (dict_struct.range_min || dict_struct.range_max)
throw Exception{full_name
@ -598,9 +678,11 @@ void registerDictionaryCache(DictionaryFactory & factory)
"for a dictionary of layout 'range_hashed'",
ErrorCodes::BAD_ARGUMENTS};
const auto & layout_prefix = config_prefix + ".layout";
const auto size = config.getInt(layout_prefix + ".cache.size_in_cells");
const size_t size = config.getUInt64(layout_prefix + ".cache.size_in_cells");
if (size == 0)
throw Exception{full_name + ": dictionary of layout 'cache' cannot have 0 cells", ErrorCodes::TOO_SMALL_BUFFER_SIZE};
throw Exception{full_name + ": dictionary of layout 'cache' cannot have 0 cells",
ErrorCodes::TOO_SMALL_BUFFER_SIZE};
const bool require_nonempty = config.getBool(config_prefix + ".require_nonempty", false);
if (require_nonempty)
@ -610,10 +692,284 @@ void registerDictionaryCache(DictionaryFactory & factory)
const String database = config.getString(config_prefix + ".database", "");
const String name = config.getString(config_prefix + ".name");
const DictionaryLifetime dict_lifetime{config, config_prefix + ".lifetime"};
return std::make_unique<CacheDictionary>(database, name, dict_struct, std::move(source_ptr), dict_lifetime, size);
const size_t max_update_queue_size =
config.getUInt64(layout_prefix + ".cache.max_update_queue_size", 100000);
if (max_update_queue_size == 0)
throw Exception{name + ": dictionary of layout 'cache' cannot have empty update queue of size 0",
ErrorCodes::TOO_SMALL_BUFFER_SIZE};
const bool allow_read_expired_keys =
config.getBool(layout_prefix + ".cache.allow_read_expired_keys", false);
const size_t update_queue_push_timeout_milliseconds =
config.getUInt64(layout_prefix + ".cache.update_queue_push_timeout_milliseconds", 10);
if (update_queue_push_timeout_milliseconds < 10)
throw Exception{name + ": dictionary of layout 'cache' have too little update_queue_push_timeout",
ErrorCodes::BAD_ARGUMENTS};
const size_t max_threads_for_updates =
config.getUInt64(layout_prefix + ".max_threads_for_updates", 4);
if (max_threads_for_updates == 0)
throw Exception{name + ": dictionary of layout 'cache' cannot have zero threads for updates.",
ErrorCodes::BAD_ARGUMENTS};
return std::make_unique<CacheDictionary>(
database, name, dict_struct, std::move(source_ptr), dict_lifetime, size,
allow_read_expired_keys, max_update_queue_size, update_queue_push_timeout_milliseconds,
max_threads_for_updates);
};
factory.registerLayout("cache", create_layout, false);
}
void CacheDictionary::updateThreadFunction()
{
setThreadName("AsyncUpdater");
while (!finished)
{
UpdateUnitPtr first_popped;
update_queue.pop(first_popped);
if (finished)
break;
/// Here we pop as many unit pointers from update queue as we can.
/// We fix current size to avoid livelock (or too long waiting),
/// when this thread pops from the queue and other threads push to the queue.
const size_t current_queue_size = update_queue.size();
if (current_queue_size > 0)
LOG_TRACE(log, "Performing bunch of keys update in cache dictionary with "
<< current_queue_size + 1 << " keys");
std::vector<UpdateUnitPtr> update_request;
update_request.reserve(current_queue_size + 1);
update_request.emplace_back(first_popped);
UpdateUnitPtr current_unit_ptr;
while (update_request.size() && update_queue.tryPop(current_unit_ptr))
update_request.emplace_back(std::move(current_unit_ptr));
BunchUpdateUnit bunch_update_unit(update_request);
try
{
/// Update a bunch of ids.
update(bunch_update_unit);
/// Notify all threads about finished updating the bunch of ids
/// where their own ids were included.
std::unique_lock<std::mutex> lock(update_mutex);
for (auto & unit_ptr: update_request)
unit_ptr->is_done = true;
is_update_finished.notify_all();
}
catch (...)
{
std::unique_lock<std::mutex> lock(update_mutex);
/// It is a big trouble, because one bad query can make other threads fail with not relative exception.
/// So at this point all threads (and queries) will receive the same exception.
for (auto & unit_ptr: update_request)
unit_ptr->current_exception = std::current_exception();
is_update_finished.notify_all();
}
}
}
void CacheDictionary::waitForCurrentUpdateFinish(UpdateUnitPtr & update_unit_ptr) const
{
std::unique_lock<std::mutex> lock(update_mutex);
/*
* We wait here without any timeout to avoid SEGFAULT's.
* Consider timeout for wait had expired and main query's thread ended with exception
* or some other error. But the UpdateUnit with callbacks is left in the queue.
* It has these callback that capture god knows what from the current thread
* (most of the variables lies on the stack of finished thread) that
* intended to do a synchronous update. AsyncUpdate thread can touch deallocated memory and explode.
* */
is_update_finished.wait(
lock,
[&] {return update_unit_ptr->is_done || update_unit_ptr->current_exception; });
if (update_unit_ptr->current_exception)
std::rethrow_exception(update_unit_ptr->current_exception);
}
void CacheDictionary::tryPushToUpdateQueueOrThrow(UpdateUnitPtr & update_unit_ptr) const
{
if (!update_queue.tryPush(update_unit_ptr, update_queue_push_timeout_milliseconds))
throw DB::Exception(
"Cannot push to internal update queue in dictionary " + getFullName() + ". Timelimit of " +
std::to_string(update_queue_push_timeout_milliseconds) + " ms. exceeded. Current queue size is " +
std::to_string(update_queue.size()), ErrorCodes::CACHE_DICTIONARY_UPDATE_FAIL);
}
void CacheDictionary::update(BunchUpdateUnit & bunch_update_unit) const
{
CurrentMetrics::Increment metric_increment{CurrentMetrics::DictCacheRequests};
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequested, bunch_update_unit.getRequestedIds().size());
std::unordered_map<Key, UInt8> remaining_ids{bunch_update_unit.getRequestedIds().size()};
for (const auto id : bunch_update_unit.getRequestedIds())
remaining_ids.insert({id, 0});
const auto now = std::chrono::system_clock::now();
if (now > backoff_end_time.load())
{
try
{
if (error_count)
{
/// Recover after error: we have to clone the source here because
/// it could keep connections which should be reset after error.
source_ptr = source_ptr->clone();
}
Stopwatch watch;
auto stream = source_ptr->loadIds(bunch_update_unit.getRequestedIds());
const ProfilingScopedWriteRWLock write_lock{rw_lock, ProfileEvents::DictCacheLockWriteNs};
stream->readPrefix();
while (const auto block = stream->read())
{
const auto id_column = typeid_cast<const ColumnUInt64 *>(block.safeGetByPosition(0).column.get());
if (!id_column)
throw Exception{name + ": id column has type different from UInt64.", ErrorCodes::TYPE_MISMATCH};
const auto & ids = id_column->getData();
/// cache column pointers
const auto column_ptrs = ext::map<std::vector>(
ext::range(0, attributes.size()), [&block](size_t i) { return block.safeGetByPosition(i + 1).column.get(); });
for (const auto i : ext::range(0, ids.size()))
{
const auto id = ids[i];
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx];
for (const auto attribute_idx : ext::range(0, attributes.size()))
{
const auto & attribute_column = *column_ptrs[attribute_idx];
auto & attribute = attributes[attribute_idx];
setAttributeValue(attribute, cell_idx, attribute_column[i]);
}
/// if cell id is zero and zero does not map to this cell, then the cell is unused
if (cell.id == 0 && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
}
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
bunch_update_unit.informCallersAboutPresentId(id, cell_idx);
/// mark corresponding id as found
remaining_ids[id] = 1;
}
}
stream->readSuffix();
error_count = 0;
last_exception = std::exception_ptr{};
backoff_end_time = std::chrono::system_clock::time_point{};
ProfileEvents::increment(ProfileEvents::DictCacheRequestTimeNs, watch.elapsed());
}
catch (...)
{
++error_count;
last_exception = std::current_exception();
backoff_end_time = now + std::chrono::seconds(calculateDurationWithBackoff(rnd_engine, error_count));
tryLogException(last_exception, log, "Could not update cache dictionary '" + getFullName() +
"', next update is scheduled at " + ext::to_string(backoff_end_time.load()));
}
}
size_t not_found_num = 0, found_num = 0;
const ProfilingScopedWriteRWLock write_lock{rw_lock, ProfileEvents::DictCacheLockWriteNs};
/// Check which ids have not been found and require setting null_value
for (const auto & id_found_pair : remaining_ids)
{
if (id_found_pair.second)
{
++found_num;
continue;
}
++not_found_num;
const auto id = id_found_pair.first;
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx];
if (error_count)
{
if (find_result.outdated)
{
/// We have expired data for that `id` so we can continue using it.
bool was_default = cell.isDefault();
cell.setExpiresAt(backoff_end_time);
if (was_default)
cell.setDefault();
if (was_default)
bunch_update_unit.informCallersAboutAbsentId(id, cell_idx);
else
bunch_update_unit.informCallersAboutPresentId(id, cell_idx);
continue;
}
/// We don't have expired data for that `id` so all we can do is to rethrow `last_exception`.
std::rethrow_exception(last_exception);
}
/// Check if cell had not been occupied before and increment element counter if it hadn't
if (cell.id == 0 && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
}
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
/// Set null_value for each attribute
cell.setDefault();
for (auto & attribute : attributes)
setDefaultAttributeValue(attribute, cell_idx);
/// inform caller that the cell has not been found
bunch_update_unit.informCallersAboutAbsentId(id, cell_idx);
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedMiss, not_found_num);
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedFound, found_num);
ProfileEvents::increment(ProfileEvents::DictCacheRequests);
}
}

View File

@ -4,12 +4,16 @@
#include <chrono>
#include <cmath>
#include <map>
#include <mutex>
#include <shared_mutex>
#include <utility>
#include <variant>
#include <vector>
#include <common/logger_useful.h>
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnString.h>
#include <Common/ThreadPool.h>
#include <Common/ConcurrentBoundedQueue.h>
#include <pcg_random.hpp>
#include <Common/ArenaWithFreeLists.h>
#include <Common/CurrentMetrics.h>
@ -21,6 +25,22 @@
namespace DB
{
namespace ErrorCodes
{
extern const int CACHE_DICTIONARY_UPDATE_FAIL;
}
/*
*
* This dictionary is stored in a cache that has a fixed number of cells.
* These cells contain frequently used elements.
* When searching for a dictionary, the cache is searched first and special heuristic is used:
* while looking for the key, we take a look only at max_collision_length elements.
* So, our cache is not perfect. It has errors like "the key is in cache, but the cache says that it does not".
* And in this case we simply ask external source for the key which is faster.
* You have to keep this logic in mind.
* */
class CacheDictionary final : public IDictionary
{
public:
@ -29,8 +49,14 @@ public:
const std::string & name_,
const DictionaryStructure & dict_struct_,
DictionarySourcePtr source_ptr_,
const DictionaryLifetime dict_lifetime_,
const size_t size_);
DictionaryLifetime dict_lifetime_,
size_t size_,
bool allow_read_expired_keys_,
size_t max_update_queue_size_,
size_t update_queue_push_timeout_milliseconds_,
size_t max_threads_for_updates);
~CacheDictionary() override;
const std::string & getDatabase() const override { return database; }
const std::string & getName() const override { return name; }
@ -55,7 +81,10 @@ public:
std::shared_ptr<const IExternalLoadable> clone() const override
{
return std::make_shared<CacheDictionary>(database, name, dict_struct, source_ptr->clone(), dict_lifetime, size);
return std::make_shared<CacheDictionary>(
database, name, dict_struct, source_ptr->clone(), dict_lifetime, size,
allow_read_expired_keys, max_update_queue_size,
update_queue_push_timeout_milliseconds, max_threads_for_updates);
}
const IDictionarySource * getSource() const override { return source_ptr.get(); }
@ -230,9 +259,6 @@ private:
template <typename DefaultGetter>
void getItemsString(Attribute & attribute, const PaddedPODArray<Key> & ids, ColumnString * out, DefaultGetter && get_default) const;
template <typename PresentIdHandler, typename AbsentIdHandler>
void update(const std::vector<Key> & requested_ids, PresentIdHandler && on_cell_updated, AbsentIdHandler && on_id_not_found) const;
PaddedPODArray<Key> getCachedIds() const;
bool isEmptyCell(const UInt64 idx) const;
@ -263,6 +289,11 @@ private:
const DictionaryStructure dict_struct;
mutable DictionarySourcePtr source_ptr;
const DictionaryLifetime dict_lifetime;
const bool allow_read_expired_keys;
const size_t max_update_queue_size;
const size_t update_queue_push_timeout_milliseconds;
const size_t max_threads_for_updates;
Logger * const log;
mutable std::shared_mutex rw_lock;
@ -284,8 +315,8 @@ private:
std::unique_ptr<ArenaWithFreeLists> string_arena;
mutable std::exception_ptr last_exception;
mutable size_t error_count = 0;
mutable std::chrono::system_clock::time_point backoff_end_time;
mutable std::atomic<size_t> error_count = 0;
mutable std::atomic<std::chrono::system_clock::time_point> backoff_end_time{std::chrono::system_clock::time_point{}};
mutable pcg64 rnd_engine;
@ -293,6 +324,166 @@ private:
mutable std::atomic<size_t> element_count{0};
mutable std::atomic<size_t> hit_count{0};
mutable std::atomic<size_t> query_count{0};
};
/// Field and methods correlated with update expired and not found keys
using PresentIdHandler = std::function<void(Key, size_t)>;
using AbsentIdHandler = std::function<void(Key, size_t)>;
/*
* Disclaimer: this comment is written not for fun.
*
* How the update goes: we basically have a method like get(keys)->values. Values are cached, so sometimes we
* can return them from the cache. For values not in cache, we query them from the dictionary, and add to the
* cache. The cache is lossy, so we can't expect it to store all the keys, and we store them separately. Normally,
* they would be passed as a return value of get(), but for Unknown Reasons the dictionaries use a baroque
* interface where get() accepts two callback, one that it calls for found values, and one for not found.
*
* Now we make it even uglier by doing this from multiple threads. The missing values are retreived from the
* dictionary in a background thread, and this thread calls the provided callback. So if you provide the callbacks,
* you MUST wait until the background update finishes, or god knows what happens. Unfortunately, we have no
* way to check that you did this right, so good luck.
*/
struct UpdateUnit
{
UpdateUnit(std::vector<Key> requested_ids_,
PresentIdHandler present_id_handler_,
AbsentIdHandler absent_id_handler_) :
requested_ids(std::move(requested_ids_)),
present_id_handler(present_id_handler_),
absent_id_handler(absent_id_handler_) {}
explicit UpdateUnit(std::vector<Key> requested_ids_) :
requested_ids(std::move(requested_ids_)),
present_id_handler([](Key, size_t){}),
absent_id_handler([](Key, size_t){}) {}
std::vector<Key> requested_ids;
PresentIdHandler present_id_handler;
AbsentIdHandler absent_id_handler;
std::atomic<bool> is_done{false};
std::exception_ptr current_exception{nullptr};
};
using UpdateUnitPtr = std::shared_ptr<UpdateUnit>;
using UpdateQueue = ConcurrentBoundedQueue<UpdateUnitPtr>;
/*
* This class is used to concatenate requested_keys.
*
* Imagine that we have several UpdateUnit with different vectors of keys and callbacks for that keys.
* We concatenate them into a long vector of keys that looks like:
*
* a1...ak_a b1...bk_2 c1...ck_3,
*
* where a1...ak_a are requested_keys from the first UpdateUnit.
* In addition we have the same number (three in this case) of callbacks.
* This class helps us to find a callback (or many callbacks) for a special key.
* */
class BunchUpdateUnit
{
public:
explicit BunchUpdateUnit(std::vector<UpdateUnitPtr> & update_request)
{
/// Here we prepare total count of all requested ids
/// not to do useless allocations later.
size_t total_requested_keys_count = 0;
for (auto & unit_ptr: update_request)
{
total_requested_keys_count += unit_ptr->requested_ids.size();
if (helper.empty())
helper.push_back(unit_ptr->requested_ids.size());
else
helper.push_back(unit_ptr->requested_ids.size() + helper.back());
present_id_handlers.emplace_back(unit_ptr->present_id_handler);
absent_id_handlers.emplace_back(unit_ptr->absent_id_handler);
}
concatenated_requested_ids.reserve(total_requested_keys_count);
for (auto & unit_ptr: update_request)
std::for_each(std::begin(unit_ptr->requested_ids), std::end(unit_ptr->requested_ids),
[&] (const Key & key) {concatenated_requested_ids.push_back(key);});
}
const std::vector<Key> & getRequestedIds()
{
return concatenated_requested_ids;
}
void informCallersAboutPresentId(Key id, size_t cell_idx)
{
for (size_t i = 0; i < concatenated_requested_ids.size(); ++i)
{
auto & curr = concatenated_requested_ids[i];
if (curr == id)
getPresentIdHandlerForPosition(i)(id, cell_idx);
}
}
void informCallersAboutAbsentId(Key id, size_t cell_idx)
{
for (size_t i = 0; i < concatenated_requested_ids.size(); ++i)
if (concatenated_requested_ids[i] == id)
getAbsentIdHandlerForPosition(i)(id, cell_idx);
}
private:
PresentIdHandler & getPresentIdHandlerForPosition(size_t position)
{
return present_id_handlers[getUpdateUnitNumberForRequestedIdPosition(position)];
}
AbsentIdHandler & getAbsentIdHandlerForPosition(size_t position)
{
return absent_id_handlers[getUpdateUnitNumberForRequestedIdPosition((position))];
}
size_t getUpdateUnitNumberForRequestedIdPosition(size_t position)
{
return std::lower_bound(helper.begin(), helper.end(), position) - helper.begin();
}
std::vector<Key> concatenated_requested_ids;
std::vector<PresentIdHandler> present_id_handlers;
std::vector<AbsentIdHandler> absent_id_handlers;
std::vector<size_t> helper;
};
mutable UpdateQueue update_queue;
ThreadPool update_pool;
/*
* Actually, we can divide all requested keys into two 'buckets'. There are only four possible states and they
* are described in the table.
*
* cache_not_found_ids |0|0|1|1|
* cache_expired_ids |0|1|0|1|
*
* 0 - if set is empty, 1 - otherwise
*
* Only if there are no cache_not_found_ids and some cache_expired_ids
* (with allow_read_expired_keys_from_cache_dictionary setting) we can perform async update.
* Otherwise we have no concatenate ids and update them sync.
*
*/
void updateThreadFunction();
void update(BunchUpdateUnit & bunch_update_unit) const;
void tryPushToUpdateQueueOrThrow(UpdateUnitPtr & update_unit_ptr) const;
void waitForCurrentUpdateFinish(UpdateUnitPtr & update_unit_ptr) const;
mutable std::mutex update_mutex;
mutable std::condition_variable is_update_finished;
std::atomic<bool> finished{false};
};
}

View File

@ -40,11 +40,13 @@ void CacheDictionary::getItemsNumberImpl(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const
{
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
std::unordered_map<Key, std::vector<size_t>> outdated_ids;
std::unordered_map<Key, std::vector<size_t>> cache_expired_ids;
std::unordered_map<Key, std::vector<size_t>> cache_not_found_ids;
auto & attribute_array = std::get<ContainerPtrType<AttributeType>>(attribute.arrays);
const auto rows = ext::size(ids);
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
size_t cache_hit = 0;
{
const ProfilingScopedReadRWLock read_lock{rw_lock, ProfileEvents::DictCacheLockReadNs};
@ -61,52 +63,105 @@ void CacheDictionary::getItemsNumberImpl(
* 3. explicit defaults were specified and cell was set default. */
const auto find_result = findCellIdx(id, now);
auto update_routine = [&]()
{
const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
out[row] = cell.isDefault() ? get_default(row) : static_cast<OutputType>(attribute_array[cell_idx]);
};
if (!find_result.valid)
{
outdated_ids[id].push_back(row);
if (find_result.outdated)
++cache_expired;
{
cache_expired_ids[id].push_back(row);
if (allow_read_expired_keys)
update_routine();
}
else
++cache_not_found;
{
cache_not_found_ids[id].push_back(row);
}
}
else
{
++cache_hit;
const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
out[row] = cell.isDefault() ? get_default(row) : static_cast<OutputType>(attribute_array[cell_idx]);
update_routine();
}
}
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
hit_count.fetch_add(rows - cache_expired_ids.size() - cache_not_found_ids.size(), std::memory_order_release);
if (outdated_ids.empty())
return;
if (cache_not_found_ids.empty())
{
/// Nothing to update - return
if (cache_expired_ids.empty())
return;
std::vector<Key> required_ids(outdated_ids.size());
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), [](auto & pair) { return pair.first; });
/// request new values
update(
required_ids,
[&](const auto id, const auto cell_idx)
/// Update async only if allow_read_expired_keys_is_enabledadd condvar usage and better code
if (allow_read_expired_keys)
{
const auto attribute_value = attribute_array[cell_idx];
std::vector<Key> required_expired_ids;
required_expired_ids.reserve(cache_expired_ids.size());
std::transform(std::begin(cache_expired_ids), std::end(cache_expired_ids), std::back_inserter(required_expired_ids),
[](auto & pair) { return pair.first; });
for (const size_t row : outdated_ids[id])
out[row] = static_cast<OutputType>(attribute_value);
},
[&](const auto id, const auto)
{
for (const size_t row : outdated_ids[id])
out[row] = get_default(row);
});
/// request new values
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_expired_ids);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
/// Nothing to do - return
return;
}
}
/// From this point we have to update all keys sync.
/// Maybe allow_read_expired_keys_from_cache_dictionary is disabled
/// and there no cache_not_found_ids but some cache_expired.
std::vector<Key> required_ids;
required_ids.reserve(cache_not_found_ids.size() + cache_expired_ids.size());
std::transform(
std::begin(cache_not_found_ids), std::end(cache_not_found_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
std::transform(
std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
auto on_cell_updated = [&] (const auto id, const auto cell_idx)
{
const auto attribute_value = attribute_array[cell_idx];
for (const size_t row : cache_not_found_ids[id])
out[row] = static_cast<OutputType>(attribute_value);
for (const size_t row : cache_expired_ids[id])
out[row] = static_cast<OutputType>(attribute_value);
};
auto on_id_not_found = [&] (const auto id, const auto)
{
for (const size_t row : cache_not_found_ids[id])
out[row] = get_default(row);
for (const size_t row : cache_expired_ids[id])
out[row] = get_default(row);
};
/// Request new values
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_ids, on_cell_updated, on_id_not_found);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
waitForCurrentUpdateFinish(update_unit_ptr);
}
template <typename DefaultGetter>
@ -161,12 +216,13 @@ void CacheDictionary::getItemsString(
out->getOffsets().resize_assume_reserved(0);
/// Mapping: <id> -> { all indices `i` of `ids` such that `ids[i]` = <id> }
std::unordered_map<Key, std::vector<size_t>> outdated_ids;
std::unordered_map<Key, std::vector<size_t>> cache_expired_ids;
std::unordered_map<Key, std::vector<size_t>> cache_not_found_ids;
/// we are going to store every string separately
std::unordered_map<Key, String> map;
size_t total_length = 0;
size_t cache_expired = 0, cache_not_found = 0, cache_hit = 0;
size_t cache_hit = 0;
{
const ProfilingScopedReadRWLock read_lock{rw_lock, ProfileEvents::DictCacheLockReadNs};
@ -176,17 +232,10 @@ void CacheDictionary::getItemsString(
const auto id = ids[row];
const auto find_result = findCellIdx(id, now);
if (!find_result.valid)
auto insert_value_routine = [&]()
{
outdated_ids[id].push_back(row);
if (find_result.outdated)
++cache_expired;
else
++cache_not_found;
}
else
{
++cache_hit;
const auto & cell_idx = find_result.cell_idx;
const auto & cell = cells[cell_idx];
const auto string_ref = cell.isDefault() ? get_default(row) : attribute_array[cell_idx];
@ -195,37 +244,82 @@ void CacheDictionary::getItemsString(
map[id] = String{string_ref};
total_length += string_ref.size + 1;
};
if (!find_result.valid)
{
if (find_result.outdated)
{
cache_expired_ids[id].push_back(row);
if (allow_read_expired_keys)
insert_value_routine();
} else
cache_not_found_ids[id].push_back(row);
} else
{
++cache_hit;
insert_value_routine();
}
}
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired);
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found);
ProfileEvents::increment(ProfileEvents::DictCacheKeysExpired, cache_expired_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysNotFound, cache_not_found_ids.size());
ProfileEvents::increment(ProfileEvents::DictCacheKeysHit, cache_hit);
query_count.fetch_add(rows, std::memory_order_relaxed);
hit_count.fetch_add(rows - outdated_ids.size(), std::memory_order_release);
hit_count.fetch_add(rows - cache_expired_ids.size() - cache_not_found_ids.size(), std::memory_order_release);
/// request new values
if (!outdated_ids.empty())
/// Async update of expired keys.
if (cache_not_found_ids.empty())
{
std::vector<Key> required_ids(outdated_ids.size());
std::transform(std::begin(outdated_ids), std::end(outdated_ids), std::begin(required_ids), [](auto & pair) { return pair.first; });
if (allow_read_expired_keys && !cache_expired_ids.empty())
{
std::vector<Key> required_expired_ids;
required_expired_ids.reserve(cache_not_found_ids.size());
std::transform(std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_expired_ids), [](auto & pair) { return pair.first; });
update(
required_ids,
[&](const auto id, const auto cell_idx)
{
const auto attribute_value = attribute_array[cell_idx];
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_expired_ids);
map[id] = String{attribute_value};
total_length += (attribute_value.size + 1) * outdated_ids[id].size();
},
[&](const auto id, const auto)
{
for (const auto row : outdated_ids[id])
total_length += get_default(row).size + 1;
});
tryPushToUpdateQueueOrThrow(update_unit_ptr);
/// Do not return at this point, because there some extra stuff to do at the end of this method.
}
}
/// Request new values sync.
/// We have request both cache_not_found_ids and cache_expired_ids.
if (!cache_not_found_ids.empty())
{
std::vector<Key> required_ids;
required_ids.reserve(cache_not_found_ids.size() + cache_expired_ids.size());
std::transform(
std::begin(cache_not_found_ids), std::end(cache_not_found_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
std::transform(
std::begin(cache_expired_ids), std::end(cache_expired_ids),
std::back_inserter(required_ids), [](auto & pair) { return pair.first; });
auto on_cell_updated = [&] (const auto id, const auto cell_idx)
{
const auto attribute_value = attribute_array[cell_idx];
map[id] = String{attribute_value};
total_length += (attribute_value.size + 1) * cache_not_found_ids[id].size();
};
auto on_id_not_found = [&] (const auto id, const auto)
{
for (const auto row : cache_not_found_ids[id])
total_length += get_default(row).size + 1;
};
auto update_unit_ptr = std::make_shared<UpdateUnit>(required_ids, on_cell_updated, on_id_not_found);
tryPushToUpdateQueueOrThrow(update_unit_ptr);
waitForCurrentUpdateFinish(update_unit_ptr);
}
out->getChars().reserve(total_length);
@ -240,167 +334,4 @@ void CacheDictionary::getItemsString(
}
}
template <typename PresentIdHandler, typename AbsentIdHandler>
void CacheDictionary::update(
const std::vector<Key> & requested_ids, PresentIdHandler && on_cell_updated, AbsentIdHandler && on_id_not_found) const
{
CurrentMetrics::Increment metric_increment{CurrentMetrics::DictCacheRequests};
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequested, requested_ids.size());
std::unordered_map<Key, UInt8> remaining_ids{requested_ids.size()};
for (const auto id : requested_ids)
remaining_ids.insert({id, 0});
const auto now = std::chrono::system_clock::now();
const ProfilingScopedWriteRWLock write_lock{rw_lock, ProfileEvents::DictCacheLockWriteNs};
if (now > backoff_end_time)
{
try
{
if (error_count)
{
/// Recover after error: we have to clone the source here because
/// it could keep connections which should be reset after error.
source_ptr = source_ptr->clone();
}
Stopwatch watch;
auto stream = source_ptr->loadIds(requested_ids);
stream->readPrefix();
while (const auto block = stream->read())
{
const auto id_column = typeid_cast<const ColumnUInt64 *>(block.safeGetByPosition(0).column.get());
if (!id_column)
throw Exception{name + ": id column has type different from UInt64.", ErrorCodes::TYPE_MISMATCH};
const auto & ids = id_column->getData();
/// cache column pointers
const auto column_ptrs = ext::map<std::vector>(
ext::range(0, attributes.size()), [&block](size_t i) { return block.safeGetByPosition(i + 1).column.get(); });
for (const auto i : ext::range(0, ids.size()))
{
const auto id = ids[i];
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx];
for (const auto attribute_idx : ext::range(0, attributes.size()))
{
const auto & attribute_column = *column_ptrs[attribute_idx];
auto & attribute = attributes[attribute_idx];
setAttributeValue(attribute, cell_idx, attribute_column[i]);
}
/// if cell id is zero and zero does not map to this cell, then the cell is unused
if (cell.id == 0 && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
}
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
/// inform caller
on_cell_updated(id, cell_idx);
/// mark corresponding id as found
remaining_ids[id] = 1;
}
}
stream->readSuffix();
error_count = 0;
last_exception = std::exception_ptr{};
backoff_end_time = std::chrono::system_clock::time_point{};
ProfileEvents::increment(ProfileEvents::DictCacheRequestTimeNs, watch.elapsed());
}
catch (...)
{
++error_count;
last_exception = std::current_exception();
backoff_end_time = now + std::chrono::seconds(calculateDurationWithBackoff(rnd_engine, error_count));
tryLogException(last_exception, log, "Could not update cache dictionary '" + getFullName() +
"', next update is scheduled at " + ext::to_string(backoff_end_time));
}
}
size_t not_found_num = 0, found_num = 0;
/// Check which ids have not been found and require setting null_value
for (const auto & id_found_pair : remaining_ids)
{
if (id_found_pair.second)
{
++found_num;
continue;
}
++not_found_num;
const auto id = id_found_pair.first;
const auto find_result = findCellIdx(id, now);
const auto & cell_idx = find_result.cell_idx;
auto & cell = cells[cell_idx];
if (error_count)
{
if (find_result.outdated)
{
/// We have expired data for that `id` so we can continue using it.
bool was_default = cell.isDefault();
cell.setExpiresAt(backoff_end_time);
if (was_default)
cell.setDefault();
if (was_default)
on_id_not_found(id, cell_idx);
else
on_cell_updated(id, cell_idx);
continue;
}
/// We don't have expired data for that `id` so all we can do is to rethrow `last_exception`.
std::rethrow_exception(last_exception);
}
/// Check if cell had not been occupied before and increment element counter if it hadn't
if (cell.id == 0 && cell_idx != zero_cell_idx)
element_count.fetch_add(1, std::memory_order_relaxed);
cell.id = id;
if (dict_lifetime.min_sec != 0 && dict_lifetime.max_sec != 0)
{
std::uniform_int_distribution<UInt64> distribution{dict_lifetime.min_sec, dict_lifetime.max_sec};
cell.setExpiresAt(now + std::chrono::seconds{distribution(rnd_engine)});
}
else
cell.setExpiresAt(std::chrono::time_point<std::chrono::system_clock>::max());
/// Set null_value for each attribute
cell.setDefault();
for (auto & attribute : attributes)
setDefaultAttributeValue(attribute, cell_idx);
/// inform caller that the cell has not been found
on_id_not_found(id, cell_idx);
}
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedMiss, not_found_num);
ProfileEvents::increment(ProfileEvents::DictCacheKeysRequestedFound, found_num);
ProfileEvents::increment(ProfileEvents::DictCacheRequests);
}
}

View File

@ -73,9 +73,15 @@ public:
for (size_t i = 0; i < input_rows_count; ++i)
{
StringRef source = column_concrete->getDataAt(i);
int status = 0;
std::string demangled = demangle(source.data, status);
result_column->insertDataWithTerminatingZero(demangled.data(), demangled.size() + 1);
auto demangled = tryDemangle(source.data);
if (demangled)
{
result_column->insertDataWithTerminatingZero(demangled.get(), strlen(demangled.get()));
}
else
{
result_column->insertDataWithTerminatingZero(source.data, source.size);
}
}
block.getByPosition(result).column = std::move(result_column);

View File

@ -688,7 +688,7 @@ void Context::calculateAccessRights()
{
auto lock = getLock();
if (user)
std::atomic_store(&access_rights, getAccessControlManager().getAccessRightsContext(client_info, user->access, settings, current_database));
std::atomic_store(&access_rights, getAccessControlManager().getAccessRightsContext(user, client_info, settings, current_database));
}
void Context::setProfile(const String & profile)
@ -701,9 +701,18 @@ void Context::setProfile(const String & profile)
settings_constraints = std::move(new_constraints);
}
std::shared_ptr<const User> Context::getUser(const String & user_name) const
std::shared_ptr<const User> Context::getUser() const
{
return shared->access_control_manager.getUser(user_name);
if (!user)
throw Exception("No current user", ErrorCodes::LOGICAL_ERROR);
return user;
}
UUID Context::getUserID() const
{
if (!user)
throw Exception("No current user", ErrorCodes::LOGICAL_ERROR);
return user_id;
}
void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key)
@ -717,8 +726,9 @@ void Context::setUser(const String & name, const String & password, const Poco::
if (!quota_key.empty())
client_info.quota_key = quota_key;
user_id = shared->access_control_manager.getID<User>(name);
user = shared->access_control_manager.authorizeAndGetUser(
name,
user_id,
password,
address.host(),
[this](const UserPtr & changed_user)

View File

@ -4,6 +4,7 @@
#include <Core/NamesAndTypes.h>
#include <Core/Settings.h>
#include <Core/Types.h>
#include <Core/UUID.h>
#include <DataStreams/IBlockStream_fwd.h>
#include <Interpreters/ClientInfo.h>
#include <Parsers/IAST_fwd.h>
@ -161,6 +162,7 @@ private:
InputBlocksReader input_blocks_reader;
std::shared_ptr<const User> user;
UUID user_id;
SubscriptionForUserChange subscription_for_user_change;
std::shared_ptr<const AccessRightsContext> access_rights;
std::shared_ptr<QuotaContext> quota; /// Current quota. By default - empty quota, that have no limits.
@ -260,10 +262,8 @@ public:
/// Must be called before getClientInfo.
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key);
std::shared_ptr<const User> getUser() const { return user; }
/// Used by MySQL Secure Password Authentication plugin.
std::shared_ptr<const User> getUser(const String & user_name) const;
std::shared_ptr<const User> getUser() const;
UUID getUserID() const;
/// We have to copy external tables inside executeQuery() to track limits. Therefore, set callback for it. Must set once.
void setExternalTablesInitializer(ExternalTablesInitializer && initializer);

View File

@ -0,0 +1,72 @@
#include <Interpreters/InterpreterCreateUserQuery.h>
#include <Parsers/ASTCreateUserQuery.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/User.h>
namespace DB
{
BlockIO InterpreterCreateUserQuery::execute()
{
const auto & query = query_ptr->as<const ASTCreateUserQuery &>();
auto & access_control = context.getAccessControlManager();
context.checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER);
if (query.alter)
{
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone());
updateUserFromQuery(*updated_user, query);
return updated_user;
};
if (query.if_exists)
{
if (auto id = access_control.find<User>(query.name))
access_control.tryUpdate(*id, update_func);
}
else
access_control.update(access_control.getID<User>(query.name), update_func);
}
else
{
auto new_user = std::make_shared<User>();
updateUserFromQuery(*new_user, query);
if (query.if_not_exists)
access_control.tryInsert(new_user);
else if (query.or_replace)
access_control.insertOrReplace(new_user);
else
access_control.insert(new_user);
}
return {};
}
void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreateUserQuery & query)
{
if (query.alter)
{
if (!query.new_name.empty())
user.setName(query.new_name);
}
else
user.setName(query.name);
if (query.authentication)
user.authentication = *query.authentication;
if (query.hosts)
user.allowed_client_hosts = *query.hosts;
if (query.remove_hosts)
user.allowed_client_hosts.remove(*query.remove_hosts);
if (query.add_hosts)
user.allowed_client_hosts.add(*query.add_hosts);
if (query.profile)
user.profile = *query.profile;
}
}

View File

@ -0,0 +1,26 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class ASTCreateUserQuery;
struct User;
class InterpreterCreateUserQuery : public IInterpreter
{
public:
InterpreterCreateUserQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
private:
void updateUserFromQuery(User & quota, const ASTCreateUserQuery & query);
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -5,6 +5,7 @@
#include <Access/AccessFlags.h>
#include <Access/Quota.h>
#include <Access/RowPolicy.h>
#include <Access/User.h>
#include <boost/range/algorithm/transform.hpp>
@ -18,6 +19,16 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
switch (query.kind)
{
case Kind::USER:
{
context.checkAccess(AccessType::DROP_USER);
if (query.if_exists)
access_control.tryRemove(access_control.find<User>(query.names));
else
access_control.remove(access_control.getIDs<User>(query.names));
return {};
}
case Kind::QUOTA:
{
context.checkAccess(AccessType::DROP_QUOTA);
@ -27,6 +38,7 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
access_control.remove(access_control.getIDs<Quota>(query.names));
return {};
}
case Kind::ROW_POLICY:
{
context.checkAccess(AccessType::DROP_POLICY);

View File

@ -1,6 +1,7 @@
#include <Parsers/ASTAlterQuery.h>
#include <Parsers/ASTCheckQuery.h>
#include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTCreateUserQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTDropAccessEntityQuery.h>
@ -14,6 +15,7 @@
#include <Parsers/ASTSetQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTShowProcesslistQuery.h>
#include <Parsers/ASTShowGrantsQuery.h>
#include <Parsers/ASTShowQuotasQuery.h>
#include <Parsers/ASTShowRowPoliciesQuery.h>
#include <Parsers/ASTShowTablesQuery.h>
@ -21,10 +23,12 @@
#include <Parsers/ASTExplainQuery.h>
#include <Parsers/TablePropertiesQueriesASTs.h>
#include <Parsers/ASTWatchQuery.h>
#include <Parsers/ASTGrantQuery.h>
#include <Interpreters/InterpreterAlterQuery.h>
#include <Interpreters/InterpreterCheckQuery.h>
#include <Interpreters/InterpreterCreateQuery.h>
#include <Interpreters/InterpreterCreateUserQuery.h>
#include <Interpreters/InterpreterCreateQuotaQuery.h>
#include <Interpreters/InterpreterCreateRowPolicyQuery.h>
#include <Interpreters/InterpreterDescribeQuery.h>
@ -43,12 +47,14 @@
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/InterpreterShowCreateQuery.h>
#include <Interpreters/InterpreterShowProcesslistQuery.h>
#include <Interpreters/InterpreterShowGrantsQuery.h>
#include <Interpreters/InterpreterShowQuotasQuery.h>
#include <Interpreters/InterpreterShowRowPoliciesQuery.h>
#include <Interpreters/InterpreterShowTablesQuery.h>
#include <Interpreters/InterpreterSystemQuery.h>
#include <Interpreters/InterpreterUseQuery.h>
#include <Interpreters/InterpreterWatchQuery.h>
#include <Interpreters/InterpreterGrantQuery.h>
#include <Parsers/ASTSystemQuery.h>
@ -176,6 +182,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{
return std::make_unique<InterpreterWatchQuery>(query, context);
}
else if (query->as<ASTCreateUserQuery>())
{
return std::make_unique<InterpreterCreateUserQuery>(query, context);
}
else if (query->as<ASTCreateQuotaQuery>())
{
return std::make_unique<InterpreterCreateQuotaQuery>(query, context);
@ -188,10 +198,18 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{
return std::make_unique<InterpreterDropAccessEntityQuery>(query, context);
}
else if (query->as<ASTGrantQuery>())
{
return std::make_unique<InterpreterGrantQuery>(query, context);
}
else if (query->as<ASTShowCreateAccessEntityQuery>())
{
return std::make_unique<InterpreterShowCreateAccessEntityQuery>(query, context);
}
else if (query->as<ASTShowGrantsQuery>())
{
return std::make_unique<InterpreterShowGrantsQuery>(query, context);
}
else if (query->as<ASTShowQuotasQuery>())
{
return std::make_unique<InterpreterShowQuotasQuery>(query, context);

View File

@ -0,0 +1,57 @@
#include <Interpreters/InterpreterGrantQuery.h>
#include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/AccessRightsContext.h>
#include <Access/User.h>
namespace DB
{
BlockIO InterpreterGrantQuery::execute()
{
const auto & query = query_ptr->as<const ASTGrantQuery &>();
auto & access_control = context.getAccessControlManager();
context.getAccessRights()->checkGrantOption(query.access_rights_elements);
using Kind = ASTGrantQuery::Kind;
if (query.to_roles->all_roles)
throw Exception(
"Cannot " + String((query.kind == Kind::GRANT) ? "GRANT to" : "REVOKE from") + " ALL", ErrorCodes::NOT_IMPLEMENTED);
String current_database = context.getCurrentDatabase();
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone());
if (query.kind == Kind::GRANT)
{
updated_user->access.grant(query.access_rights_elements, current_database);
if (query.grant_option)
updated_user->access_with_grant_option.grant(query.access_rights_elements, current_database);
}
else if (context.getSettingsRef().partial_revokes)
{
updated_user->access_with_grant_option.partialRevoke(query.access_rights_elements, current_database);
if (!query.grant_option)
updated_user->access.partialRevoke(query.access_rights_elements, current_database);
}
else
{
updated_user->access_with_grant_option.revoke(query.access_rights_elements, current_database);
if (!query.grant_option)
updated_user->access.revoke(query.access_rights_elements, current_database);
}
return updated_user;
};
std::vector<UUID> ids = access_control.getIDs<User>(query.to_roles->roles);
if (query.to_roles->current_user)
ids.push_back(context.getUserID());
access_control.update(ids, update_func);
return {};
}
}

View File

@ -0,0 +1,20 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class InterpreterGrantQuery : public IInterpreter
{
public:
InterpreterGrantQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
private:
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -212,15 +212,20 @@ static Context getSubqueryContext(const Context & context)
return subquery_context;
}
static void sanitizeBlock(Block & block)
static bool sanitizeBlock(Block & block)
{
for (auto & col : block)
{
if (!col.column)
{
if (isNotCreatable(col.type->getTypeId()))
return false;
col.column = col.type->createColumn();
}
else if (isColumnConst(*col.column) && !col.column->empty())
col.column = col.column->cloneEmpty();
}
return true;
}
InterpreterSelectQuery::InterpreterSelectQuery(
@ -324,7 +329,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
table_id = storage->getStorageID();
}
auto analyze = [&] ()
auto analyze = [&] (bool try_move_to_prewhere = true)
{
syntax_analyzer_result = SyntaxAnalyzer(*context, options).analyze(
query_ptr, source_header.getNamesAndTypesList(), required_result_column_names, storage, NamesAndTypesList());
@ -397,7 +402,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
throw Exception("PREWHERE is not supported if the table is filtered by row-level security expression", ErrorCodes::ILLEGAL_PREWHERE);
/// Calculate structure of the result.
result_header = getSampleBlockImpl();
result_header = getSampleBlockImpl(try_move_to_prewhere);
};
analyze();
@ -425,8 +430,13 @@ InterpreterSelectQuery::InterpreterSelectQuery(
query.setExpression(ASTSelectQuery::Expression::WHERE, makeASTFunction("and", query.prewhere()->clone(), query.where()->clone()));
need_analyze_again = true;
}
if (need_analyze_again)
analyze();
{
/// Do not try move conditions to PREWHERE for the second time.
/// Otherwise, we won't be able to fallback from inefficient PREWHERE to WHERE later.
analyze(/* try_move_to_prewhere = */ false);
}
/// If there is no WHERE, filter blocks as usual
if (query.prewhere() && !query.where())
@ -509,7 +519,7 @@ QueryPipeline InterpreterSelectQuery::executeWithProcessors()
}
Block InterpreterSelectQuery::getSampleBlockImpl()
Block InterpreterSelectQuery::getSampleBlockImpl(bool try_move_to_prewhere)
{
auto & query = getSelectQuery();
const Settings & settings = context->getSettingsRef();
@ -533,7 +543,7 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
current_info.sets = query_analyzer->getPreparedSets();
/// Try transferring some condition from WHERE to PREWHERE if enabled and viable
if (settings.optimize_move_to_prewhere && query.where() && !query.prewhere() && !query.final())
if (settings.optimize_move_to_prewhere && try_move_to_prewhere && query.where() && !query.prewhere() && !query.final())
MergeTreeWhereOptimizer{current_info, *context, merge_tree,
syntax_analyzer_result->requiredSourceColumns(), log};
};
@ -608,18 +618,21 @@ Block InterpreterSelectQuery::getSampleBlockImpl()
/// Check if there is an ignore function. It's used for disabling constant folding in query
/// predicates because some performance tests use ignore function as a non-optimize guard.
static bool hasIgnore(const ExpressionActions & actions)
static bool allowEarlyConstantFolding(const ExpressionActions & actions, const Context & context)
{
if (!context.getSettingsRef().enable_early_constant_folding)
return false;
for (auto & action : actions.getActions())
{
if (action.type == action.APPLY_FUNCTION && action.function_base)
{
auto name = action.function_base->getName();
if (name == "ignore")
return true;
return false;
}
}
return false;
return true;
}
InterpreterSelectQuery::AnalysisResult
@ -726,15 +739,17 @@ InterpreterSelectQuery::analyzeExpressions(
res.prewhere_info = std::make_shared<PrewhereInfo>(
chain.steps.front().actions, query.prewhere()->getColumnName());
if (!hasIgnore(*res.prewhere_info->prewhere_actions))
if (allowEarlyConstantFolding(*res.prewhere_info->prewhere_actions, context))
{
Block before_prewhere_sample = source_header;
sanitizeBlock(before_prewhere_sample);
res.prewhere_info->prewhere_actions->execute(before_prewhere_sample);
auto & column_elem = before_prewhere_sample.getByName(query.prewhere()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
res.prewhere_constant_filter_description = ConstantFilterDescription(*column_elem.column);
if (sanitizeBlock(before_prewhere_sample))
{
res.prewhere_info->prewhere_actions->execute(before_prewhere_sample);
auto & column_elem = before_prewhere_sample.getByName(query.prewhere()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
res.prewhere_constant_filter_description = ConstantFilterDescription(*column_elem.column);
}
}
chain.addStep();
}
@ -756,19 +771,21 @@ InterpreterSelectQuery::analyzeExpressions(
where_step_num = chain.steps.size() - 1;
has_where = res.has_where = true;
res.before_where = chain.getLastActions();
if (!hasIgnore(*res.before_where))
if (allowEarlyConstantFolding(*res.before_where, context))
{
Block before_where_sample;
if (chain.steps.size() > 1)
before_where_sample = chain.steps[chain.steps.size() - 2].actions->getSampleBlock();
else
before_where_sample = source_header;
sanitizeBlock(before_where_sample);
res.before_where->execute(before_where_sample);
auto & column_elem = before_where_sample.getByName(query.where()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
res.where_constant_filter_description = ConstantFilterDescription(*column_elem.column);
if (sanitizeBlock(before_where_sample))
{
res.before_where->execute(before_where_sample);
auto & column_elem = before_where_sample.getByName(query.where()->getColumnName());
/// If the filter column is a constant, record it.
if (column_elem.column)
res.where_constant_filter_description = ConstantFilterDescription(*column_elem.column);
}
}
chain.addStep();
}

View File

@ -104,7 +104,7 @@ private:
ASTSelectQuery & getSelectQuery() { return query_ptr->as<ASTSelectQuery &>(); }
Block getSampleBlockImpl();
Block getSampleBlockImpl(bool try_move_to_prewhere);
struct Pipeline
{

View File

@ -1,5 +1,6 @@
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTCreateUserQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h>
@ -9,6 +10,7 @@
#include <Parsers/parseQuery.h>
#include <Access/AccessControlManager.h>
#include <Access/QuotaContext.h>
#include <Access/User.h>
#include <Columns/ColumnString.h>
#include <DataStreams/OneBlockInputStream.h>
#include <DataTypes/DataTypeString.h>
@ -58,6 +60,7 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(const ASTShowCreat
using Kind = ASTShowCreateAccessEntityQuery::Kind;
switch (show_query.kind)
{
case Kind::USER: return getCreateUserQuery(show_query);
case Kind::QUOTA: return getCreateQuotaQuery(show_query);
case Kind::ROW_POLICY: return getCreateRowPolicyQuery(show_query);
}
@ -65,6 +68,27 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(const ASTShowCreat
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateUserQuery(const ASTShowCreateAccessEntityQuery & show_query) const
{
UserPtr user;
if (show_query.current_user)
user = context.getUser();
else
user = context.getAccessControlManager().getUser(show_query.name);
auto create_query = std::make_shared<ASTCreateUserQuery>();
create_query->name = user->getName();
if (!user->allowed_client_hosts.containsAnyHost())
create_query->hosts = user->allowed_client_hosts;
if (!user->profile.empty())
create_query->profile = user->profile;
return create_query;
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const
{
auto & access_control = context.getAccessControlManager();

View File

@ -29,6 +29,7 @@ private:
BlockInputStreamPtr executeImpl();
ASTPtr getCreateQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
ASTPtr getCreateUserQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
ASTPtr getCreateQuotaQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
ASTPtr getCreateRowPolicyQuery(const ASTShowCreateAccessEntityQuery & show_query) const;
};

View File

@ -0,0 +1,124 @@
#include <Interpreters/InterpreterShowGrantsQuery.h>
#include <Parsers/ASTShowGrantsQuery.h>
#include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/formatAST.h>
#include <Interpreters/Context.h>
#include <Columns/ColumnString.h>
#include <DataStreams/OneBlockInputStream.h>
#include <DataTypes/DataTypeString.h>
#include <Access/AccessControlManager.h>
#include <Access/User.h>
#include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/copy.hpp>
namespace DB
{
namespace
{
std::vector<AccessRightsElements> groupByTable(AccessRightsElements && elements)
{
using Key = std::tuple<String, bool, String, bool>;
std::map<Key, AccessRightsElements> grouping_map;
for (auto & element : elements)
{
Key key(element.database, element.any_database, element.table, element.any_table);
grouping_map[key].emplace_back(std::move(element));
}
std::vector<AccessRightsElements> res;
res.reserve(grouping_map.size());
boost::range::copy(grouping_map | boost::adaptors::map_values, std::back_inserter(res));
return res;
}
struct GroupedGrantsAndPartialRevokes
{
std::vector<AccessRightsElements> grants;
std::vector<AccessRightsElements> partial_revokes;
};
GroupedGrantsAndPartialRevokes groupByTable(AccessRights::Elements && elements)
{
GroupedGrantsAndPartialRevokes res;
res.grants = groupByTable(std::move(elements.grants));
res.partial_revokes = groupByTable(std::move(elements.partial_revokes));
return res;
}
}
BlockIO InterpreterShowGrantsQuery::execute()
{
BlockIO res;
res.in = executeImpl();
return res;
}
BlockInputStreamPtr InterpreterShowGrantsQuery::executeImpl()
{
const auto & show_query = query_ptr->as<ASTShowGrantsQuery &>();
/// Build a create query.
ASTs grant_queries = getGrantQueries(show_query);
/// Build the result column.
MutableColumnPtr column = ColumnString::create();
std::stringstream grant_ss;
for (const auto & grant_query : grant_queries)
{
grant_ss.str("");
formatAST(*grant_query, grant_ss, false, true);
column->insert(grant_ss.str());
}
/// Prepare description of the result column.
std::stringstream desc_ss;
formatAST(show_query, desc_ss, false, true);
String desc = desc_ss.str();
String prefix = "SHOW ";
if (desc.starts_with(prefix))
desc = desc.substr(prefix.length()); /// `desc` always starts with "SHOW ", so we can trim this prefix.
return std::make_shared<OneBlockInputStream>(Block{{std::move(column), std::make_shared<DataTypeString>(), desc}});
}
ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show_query) const
{
UserPtr user;
if (show_query.current_user)
user = context.getUser();
else
user = context.getAccessControlManager().getUser(show_query.name);
ASTs res;
for (bool grant_option : {true, false})
{
if (!grant_option && (user->access == user->access_with_grant_option))
continue;
const auto & access_rights = grant_option ? user->access_with_grant_option : user->access;
const auto grouped_elements = groupByTable(access_rights.getElements());
using Kind = ASTGrantQuery::Kind;
for (Kind kind : {Kind::GRANT, Kind::REVOKE})
{
for (const auto & elements : (kind == Kind::GRANT ? grouped_elements.grants : grouped_elements.partial_revokes))
{
auto grant_query = std::make_shared<ASTGrantQuery>();
grant_query->kind = kind;
grant_query->grant_option = grant_option;
grant_query->to_roles = std::make_shared<ASTRoleList>();
grant_query->to_roles->roles.push_back(user->getName());
grant_query->access_rights_elements = elements;
res.push_back(std::move(grant_query));
}
}
}
return res;
}
}

View File

@ -0,0 +1,26 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class ASTShowGrantsQuery;
class InterpreterShowGrantsQuery : public IInterpreter
{
public:
InterpreterShowGrantsQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
private:
BlockInputStreamPtr executeImpl();
ASTs getGrantQueries(const ASTShowGrantsQuery & show_query) const;
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -0,0 +1,187 @@
#include <Parsers/ASTCreateUserQuery.h>
#include <Common/quoteString.h>
namespace DB
{
namespace
{
void formatRenameTo(const String & new_name, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " RENAME TO " << (settings.hilite ? IAST::hilite_none : "")
<< quoteString(new_name);
}
void formatAuthentication(const Authentication & authentication, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " IDENTIFIED WITH " << (settings.hilite ? IAST::hilite_none : "");
switch (authentication.getType())
{
case Authentication::Type::NO_PASSWORD:
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "no_password" << (settings.hilite ? IAST::hilite_none : "");
break;
case Authentication::Type::PLAINTEXT_PASSWORD:
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "plaintext_password BY " << (settings.hilite ? IAST::hilite_none : "")
<< quoteString(authentication.getPassword());
break;
case Authentication::Type::SHA256_PASSWORD:
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "sha256_hash BY " << (settings.hilite ? IAST::hilite_none : "")
<< quoteString(authentication.getPasswordHashHex());
break;
case Authentication::Type::DOUBLE_SHA1_PASSWORD:
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "double_sha1_hash BY " << (settings.hilite ? IAST::hilite_none : "")
<< quoteString(authentication.getPasswordHashHex());
break;
}
}
void formatHosts(const char * prefix, const AllowedClientHosts & hosts, const IAST::FormatSettings & settings)
{
if (prefix)
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " " << prefix << " HOST "
<< (settings.hilite ? IAST::hilite_none : "");
else
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " HOST " << (settings.hilite ? IAST::hilite_none : "");
if (hosts.empty())
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "NONE" << (settings.hilite ? IAST::hilite_none : "");
return;
}
if (hosts.containsAnyHost())
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "ANY" << (settings.hilite ? IAST::hilite_none : "");
return;
}
bool need_comma = false;
if (hosts.containsLocalHost())
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "LOCAL" << (settings.hilite ? IAST::hilite_none : "");
}
const auto & addresses = hosts.getAddresses();
const auto & subnets = hosts.getSubnets();
if (!addresses.empty() || !subnets.empty())
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "IP " << (settings.hilite ? IAST::hilite_none : "");
bool need_comma2 = false;
for (const auto & address : addresses)
{
if (std::exchange(need_comma2, true))
settings.ostr << ", ";
settings.ostr << quoteString(address.toString());
}
for (const auto & subnet : subnets)
{
if (std::exchange(need_comma2, true))
settings.ostr << ", ";
settings.ostr << quoteString(subnet.toString());
}
}
const auto & names = hosts.getNames();
if (!names.empty())
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "NAME " << (settings.hilite ? IAST::hilite_none : "");
bool need_comma2 = false;
for (const auto & name : names)
{
if (std::exchange(need_comma2, true))
settings.ostr << ", ";
settings.ostr << quoteString(name);
}
}
const auto & name_regexps = hosts.getNameRegexps();
if (!name_regexps.empty())
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "NAME REGEXP " << (settings.hilite ? IAST::hilite_none : "");
bool need_comma2 = false;
for (const auto & host_regexp : name_regexps)
{
if (std::exchange(need_comma2, true))
settings.ostr << ", ";
settings.ostr << quoteString(host_regexp);
}
}
const auto & like_patterns = hosts.getLikePatterns();
if (!like_patterns.empty())
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "LIKE " << (settings.hilite ? IAST::hilite_none : "");
bool need_comma2 = false;
for (const auto & like_pattern : like_patterns)
{
if (std::exchange(need_comma2, true))
settings.ostr << ", ";
settings.ostr << quoteString(like_pattern);
}
}
}
void formatProfile(const String & profile_name, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " PROFILE " << (settings.hilite ? IAST::hilite_none : "")
<< quoteString(profile_name);
}
}
String ASTCreateUserQuery::getID(char) const
{
return "CreateUserQuery";
}
ASTPtr ASTCreateUserQuery::clone() const
{
return std::make_shared<ASTCreateUserQuery>(*this);
}
void ASTCreateUserQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << (alter ? "ALTER USER" : "CREATE USER")
<< (settings.hilite ? hilite_none : "");
if (if_exists)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF EXISTS" << (settings.hilite ? hilite_none : "");
else if (if_not_exists)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (settings.hilite ? hilite_none : "");
else if (or_replace)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " OR REPLACE" << (settings.hilite ? hilite_none : "");
settings.ostr << " " << backQuoteIfNeed(name);
if (!new_name.empty())
formatRenameTo(new_name, settings);
if (authentication)
formatAuthentication(*authentication, settings);
if (hosts)
formatHosts(nullptr, *hosts, settings);
if (add_hosts)
formatHosts("ADD", *add_hosts, settings);
if (remove_hosts)
formatHosts("REMOVE", *remove_hosts, settings);
if (profile)
formatProfile(*profile, settings);
}
}

View File

@ -0,0 +1,45 @@
#pragma once
#include <Parsers/IAST.h>
#include <Access/Authentication.h>
#include <Access/AllowedClientHosts.h>
namespace DB
{
/** CREATE USER [IF NOT EXISTS | OR REPLACE] name
* [IDENTIFIED [WITH {NO_PASSWORD|PLAINTEXT_PASSWORD|SHA256_PASSWORD|SHA256_HASH|DOUBLE_SHA1_PASSWORD|DOUBLE_SHA1_HASH}] BY {'password'|'hash'}]
* [HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE]
* [PROFILE 'profile_name']
*
* ALTER USER [IF EXISTS] name
* [RENAME TO new_name]
* [IDENTIFIED [WITH {PLAINTEXT_PASSWORD|SHA256_PASSWORD|DOUBLE_SHA1_PASSWORD}] BY {'password'|'hash'}]
* [[ADD|REMOVE] HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE]
* [PROFILE 'profile_name']
*/
class ASTCreateUserQuery : public IAST
{
public:
bool alter = false;
bool if_exists = false;
bool if_not_exists = false;
bool or_replace = false;
String name;
String new_name;
std::optional<Authentication> authentication;
std::optional<AllowedClientHosts> hosts;
std::optional<AllowedClientHosts> add_hosts;
std::optional<AllowedClientHosts> remove_hosts;
std::optional<String> profile;
String getID(char) const override;
ASTPtr clone() const override;
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -12,6 +12,7 @@ namespace
{
switch (kind)
{
case Kind::USER: return "USER";
case Kind::QUOTA: return "QUOTA";
case Kind::ROW_POLICY: return "POLICY";
}

View File

@ -9,17 +9,21 @@ namespace DB
/** DROP QUOTA [IF EXISTS] name [,...]
* DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...]
* DROP USER [IF EXISTS] name [,...]
*/
class ASTDropAccessEntityQuery : public IAST
{
public:
enum class Kind
{
USER,
QUOTA,
ROW_POLICY,
};
const Kind kind;
const char * const keyword;
bool if_exists = false;
Strings names;
std::vector<RowPolicy::FullNameParts> row_policies_names;

View File

@ -0,0 +1,118 @@
#include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Common/quoteString.h>
#include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/sort.hpp>
#include <boost/range/algorithm_ext/push_back.hpp>
#include <map>
namespace DB
{
namespace
{
using KeywordToColumnsMap = std::map<std::string_view /* keyword */, std::vector<std::string_view> /* columns */>;
using TableToAccessMap = std::map<String /* database_and_table_name */, KeywordToColumnsMap>;
TableToAccessMap prepareTableToAccessMap(const AccessRightsElements & elements)
{
TableToAccessMap res;
for (const auto & element : elements)
{
String database_and_table_name;
if (element.any_database)
{
if (element.any_table)
database_and_table_name = "*.*";
else
database_and_table_name = "*." + backQuoteIfNeed(element.table);
}
else if (element.database.empty())
{
if (element.any_table)
database_and_table_name = "*";
else
database_and_table_name = backQuoteIfNeed(element.table);
}
else
{
if (element.any_table)
database_and_table_name = backQuoteIfNeed(element.database) + ".*";
else
database_and_table_name = backQuoteIfNeed(element.database) + "." + backQuoteIfNeed(element.table);
}
KeywordToColumnsMap & keyword_to_columns = res[database_and_table_name];
for (const auto & keyword : element.access_flags.toKeywords())
boost::range::push_back(keyword_to_columns[keyword], element.columns);
}
for (auto & keyword_to_columns : res | boost::adaptors::map_values)
{
for (auto & columns : keyword_to_columns | boost::adaptors::map_values)
boost::range::sort(columns);
}
return res;
}
void formatColumnNames(const std::vector<std::string_view> & columns, const IAST::FormatSettings & settings)
{
if (columns.empty())
return;
settings.ostr << "(";
bool need_comma_after_column_name = false;
for (const auto & column : columns)
{
if (std::exchange(need_comma_after_column_name, true))
settings.ostr << ", ";
settings.ostr << backQuoteIfNeed(column);
}
settings.ostr << ")";
}
}
String ASTGrantQuery::getID(char) const
{
return "GrantQuery";
}
ASTPtr ASTGrantQuery::clone() const
{
return std::make_shared<ASTGrantQuery>(*this);
}
void ASTGrantQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << ((kind == Kind::GRANT) ? "GRANT" : "REVOKE")
<< (settings.hilite ? hilite_none : "") << " ";
if (grant_option && (kind == Kind::REVOKE))
settings.ostr << (settings.hilite ? hilite_keyword : "") << "GRANT OPTION FOR " << (settings.hilite ? hilite_none : "");
bool need_comma = false;
for (const auto & [database_and_table, keyword_to_columns] : prepareTableToAccessMap(access_rights_elements))
{
for (const auto & [keyword, columns] : keyword_to_columns)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? hilite_keyword : "") << keyword << (settings.hilite ? hilite_none : "");
formatColumnNames(columns, settings);
}
settings.ostr << (settings.hilite ? hilite_keyword : "") << " ON " << (settings.hilite ? hilite_none : "") << database_and_table;
}
settings.ostr << (settings.hilite ? hilite_keyword : "") << ((kind == Kind::GRANT) ? " TO " : " FROM ") << (settings.hilite ? hilite_none : "");
to_roles->format(settings);
if (grant_option && (kind == Kind::GRANT))
settings.ostr << (settings.hilite ? hilite_keyword : "") << " WITH GRANT OPTION" << (settings.hilite ? hilite_none : "");
}
}

View File

@ -0,0 +1,32 @@
#pragma once
#include <Parsers/IAST.h>
#include <Access/AccessRightsElement.h>
namespace DB
{
class ASTRoleList;
/** GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name
* REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name
*/
class ASTGrantQuery : public IAST
{
public:
enum class Kind
{
GRANT,
REVOKE,
};
Kind kind = Kind::GRANT;
AccessRightsElements access_rights_elements;
std::shared_ptr<ASTRoleList> to_roles;
bool grant_option = false;
String getID(char) const override;
ASTPtr clone() const override;
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -12,6 +12,7 @@ namespace
{
switch (kind)
{
case Kind::USER: return "USER";
case Kind::QUOTA: return "QUOTA";
case Kind::ROW_POLICY: return "POLICY";
}
@ -44,7 +45,11 @@ void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & sett
<< "SHOW CREATE " << keyword
<< (settings.hilite ? hilite_none : "");
if (kind == Kind::ROW_POLICY)
if ((kind == Kind::USER) && current_user)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT_USER" << (settings.hilite ? hilite_none : "");
else if ((kind == Kind::QUOTA) && current_quota)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : "");
else if (kind == Kind::ROW_POLICY)
{
const String & database = row_policy_name.database;
const String & table_name = row_policy_name.table_name;
@ -53,8 +58,6 @@ void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & sett
<< (settings.hilite ? hilite_none : "") << (database.empty() ? String{} : backQuoteIfNeed(database) + ".")
<< backQuoteIfNeed(table_name);
}
else if ((kind == Kind::QUOTA) && current_quota)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : "");
else
settings.ostr << " " << backQuoteIfNeed(name);
}

View File

@ -8,19 +8,23 @@ namespace DB
{
/** SHOW CREATE QUOTA [name | CURRENT]
* SHOW CREATE [ROW] POLICY name ON [database.]table
* SHOW CREATE USER [name | CURRENT_USER]
*/
class ASTShowCreateAccessEntityQuery : public ASTQueryWithOutput
{
public:
enum class Kind
{
USER,
QUOTA,
ROW_POLICY,
};
const Kind kind;
const char * const keyword;
String name;
bool current_quota = false;
bool current_user = false;
RowPolicy::FullNameParts row_policy_name;
ASTShowCreateAccessEntityQuery(Kind kind_);

View File

@ -0,0 +1,30 @@
#include <Parsers/ASTShowGrantsQuery.h>
#include <Common/quoteString.h>
namespace DB
{
String ASTShowGrantsQuery::getID(char) const
{
return "ShowGrantsQuery";
}
ASTPtr ASTShowGrantsQuery::clone() const
{
return std::make_shared<ASTShowGrantsQuery>(*this);
}
void ASTShowGrantsQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << "SHOW GRANTS FOR "
<< (settings.hilite ? hilite_none : "");
if (current_user)
settings.ostr << (settings.hilite ? hilite_keyword : "") << "CURRENT_USER"
<< (settings.hilite ? hilite_none : "");
else
settings.ostr << backQuoteIfNeed(name);
}
}

View File

@ -0,0 +1,20 @@
#pragma once
#include <Parsers/ASTQueryWithOutput.h>
namespace DB
{
/** SHOW GRANTS [FOR user_name]
*/
class ASTShowGrantsQuery : public ASTQueryWithOutput
{
public:
String name;
bool current_user = false;
String getID(char) const override;
ASTPtr clone() const override;
void formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -47,6 +47,13 @@ protected:
};
class ParserBareWord : public IParserBase
{
protected:
const char * getName() const override { return "bare word"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
/** An identifier, possibly containing a dot, for example, x_yz123 or `something special` or Hits.EventTime
*/
class ParserCompoundIdentifier : public IParserBase

View File

@ -304,6 +304,8 @@ Token Lexer::nextTokenImpl()
return Token(TokenType::Concatenation, token_begin, ++pos);
return Token(TokenType::ErrorSinglePipeMark, token_begin, pos);
}
case '@':
return Token(TokenType::At, token_begin, ++pos);
default:
if (isWordCharASCII(*pos))

View File

@ -48,6 +48,8 @@ namespace DB
M(GreaterOrEquals) \
M(Concatenation) /** String concatenation operator: || */ \
\
M(At) /** @. Used only for specifying user names. */ \
\
/** Order is important. EndOfStream goes after all usual tokens, and special error tokens goes after EndOfStream. */ \
\
M(EndOfStream) \

View File

@ -0,0 +1,297 @@
#include <Parsers/ParserCreateUserQuery.h>
#include <Parsers/ASTCreateUserQuery.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/parseUserName.h>
#include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/ExpressionElementParsers.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTRoleList.h>
#include <ext/range.h>
#include <boost/algorithm/string/predicate.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int SYNTAX_ERROR;
}
namespace
{
bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_name, String & new_host_pattern)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (!new_name.empty())
return false;
if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected))
return false;
return parseUserName(pos, expected, new_name, new_host_pattern);
});
}
bool parsePassword(IParserBase::Pos & pos, Expected & expected, String & password)
{
ASTPtr ast;
if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false;
password = ast->as<const ASTLiteral &>().value.safeGet<String>();
return true;
}
bool parseAuthentication(IParserBase::Pos & pos, Expected & expected, std::optional<Authentication> & authentication)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (authentication)
return false;
if (!ParserKeyword{"IDENTIFIED"}.ignore(pos, expected))
return false;
if (ParserKeyword{"WITH"}.ignore(pos, expected))
{
if (ParserKeyword{"NO_PASSWORD"}.ignore(pos, expected))
{
authentication = Authentication{Authentication::NO_PASSWORD};
}
else if (ParserKeyword{"PLAINTEXT_PASSWORD"}.ignore(pos, expected))
{
String password;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::PLAINTEXT_PASSWORD};
authentication->setPassword(password);
}
else if (ParserKeyword{"SHA256_PASSWORD"}.ignore(pos, expected))
{
String password;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::SHA256_PASSWORD};
authentication->setPassword(password);
}
else if (ParserKeyword{"SHA256_HASH"}.ignore(pos, expected))
{
String hash;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, hash))
return false;
authentication = Authentication{Authentication::SHA256_PASSWORD};
authentication->setPasswordHashHex(hash);
}
else if (ParserKeyword{"DOUBLE_SHA1_PASSWORD"}.ignore(pos, expected))
{
String password;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD};
authentication->setPassword(password);
}
else if (ParserKeyword{"DOUBLE_SHA1_HASH"}.ignore(pos, expected))
{
String hash;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, hash))
return false;
authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD};
authentication->setPasswordHashHex(hash);
}
else
return false;
}
else
{
String password;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::SHA256_PASSWORD};
authentication->setPassword(password);
}
return true;
});
}
bool parseHosts(IParserBase::Pos & pos, Expected & expected, const char * prefix, std::optional<AllowedClientHosts> & hosts)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (prefix && !ParserKeyword{prefix}.ignore(pos, expected))
return false;
if (!ParserKeyword{"HOST"}.ignore(pos, expected))
return false;
if (ParserKeyword{"ANY"}.ignore(pos, expected))
{
if (!hosts)
hosts.emplace();
hosts->addAnyHost();
return true;
}
if (ParserKeyword{"NONE"}.ignore(pos, expected))
{
if (!hosts)
hosts.emplace();
return true;
}
do
{
if (ParserKeyword{"LOCAL"}.ignore(pos, expected))
{
if (!hosts)
hosts.emplace();
hosts->addLocalHost();
}
else if (ParserKeyword{"NAME REGEXP"}.ignore(pos, expected))
{
ASTPtr ast;
if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false;
if (!hosts)
hosts.emplace();
hosts->addNameRegexp(ast->as<const ASTLiteral &>().value.safeGet<String>());
}
else if (ParserKeyword{"NAME"}.ignore(pos, expected))
{
ASTPtr ast;
if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false;
if (!hosts)
hosts.emplace();
hosts->addName(ast->as<const ASTLiteral &>().value.safeGet<String>());
}
else if (ParserKeyword{"IP"}.ignore(pos, expected))
{
ASTPtr ast;
if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false;
if (!hosts)
hosts.emplace();
hosts->addSubnet(ast->as<const ASTLiteral &>().value.safeGet<String>());
}
else if (ParserKeyword{"LIKE"}.ignore(pos, expected))
{
ASTPtr ast;
if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false;
if (!hosts)
hosts.emplace();
hosts->addLikePattern(ast->as<const ASTLiteral &>().value.safeGet<String>());
}
else
return false;
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return true;
});
}
bool parseProfileName(IParserBase::Pos & pos, Expected & expected, std::optional<String> & profile)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (profile)
return false;
if (!ParserKeyword{"PROFILE"}.ignore(pos, expected))
return false;
ASTPtr ast;
if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false;
profile = ast->as<const ASTLiteral &>().value.safeGet<String>();
return true;
});
}
}
bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
bool alter;
if (ParserKeyword{"CREATE USER"}.ignore(pos, expected))
alter = false;
else if (ParserKeyword{"ALTER USER"}.ignore(pos, expected))
alter = true;
else
return false;
bool if_exists = false;
bool if_not_exists = false;
bool or_replace = false;
if (alter)
{
if (ParserKeyword{"IF EXISTS"}.ignore(pos, expected))
if_exists = true;
}
else
{
if (ParserKeyword{"IF NOT EXISTS"}.ignore(pos, expected))
if_not_exists = true;
else if (ParserKeyword{"OR REPLACE"}.ignore(pos, expected))
or_replace = true;
}
String name;
String host_pattern;
if (!parseUserName(pos, expected, name, host_pattern))
return false;
String new_name;
String new_host_pattern;
std::optional<Authentication> authentication;
std::optional<AllowedClientHosts> hosts;
std::optional<AllowedClientHosts> add_hosts;
std::optional<AllowedClientHosts> remove_hosts;
std::optional<String> profile;
while (parseAuthentication(pos, expected, authentication)
|| parseHosts(pos, expected, nullptr, hosts)
|| parseProfileName(pos, expected, profile)
|| (alter && parseRenameTo(pos, expected, new_name, new_host_pattern))
|| (alter && parseHosts(pos, expected, "ADD", add_hosts))
|| (alter && parseHosts(pos, expected, "REMOVE", remove_hosts)))
;
if (!hosts)
{
if (!alter)
hosts.emplace().addLikePattern(host_pattern);
else if (alter && !new_name.empty())
hosts.emplace().addLikePattern(new_host_pattern);
}
auto query = std::make_shared<ASTCreateUserQuery>();
node = query;
query->alter = alter;
query->if_exists = if_exists;
query->if_not_exists = if_not_exists;
query->or_replace = or_replace;
query->name = std::move(name);
query->new_name = std::move(new_name);
query->authentication = std::move(authentication);
query->hosts = std::move(hosts);
query->add_hosts = std::move(add_hosts);
query->remove_hosts = std::move(remove_hosts);
query->profile = std::move(profile);
return true;
}
}

View File

@ -0,0 +1,26 @@
#pragma once
#include <Parsers/IParserBase.h>
namespace DB
{
/** Parses queries like
* CREATE USER [IF NOT EXISTS | OR REPLACE] name
* [IDENTIFIED [WITH {NO_PASSWORD|PLAINTEXT_PASSWORD|SHA256_PASSWORD|SHA256_HASH|DOUBLE_SHA1_PASSWORD|DOUBLE_SHA1_HASH}] BY {'password'|'hash'}]
* [HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE]
* [PROFILE 'profile_name']
*
* ALTER USER [IF EXISTS] name
* [RENAME TO new_name]
* [IDENTIFIED [WITH {PLAINTEXT_PASSWORD|SHA256_PASSWORD|DOUBLE_SHA1_PASSWORD}] BY {'password'|'hash'}]
* [[ADD|REMOVE] HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE]
* [PROFILE 'profile_name']
*/
class ParserCreateUserQuery : public IParserBase
{
protected:
const char * getName() const override { return "CREATE USER or ALTER USER query"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
}

View File

@ -3,6 +3,7 @@
#include <Parsers/CommonParsers.h>
#include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/parseDatabaseAndTableName.h>
#include <Parsers/parseUserName.h>
#include <Access/Quota.h>
@ -23,6 +24,37 @@ namespace
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return true;
}
bool parseRowPolicyNames(IParserBase::Pos & pos, Expected & expected, std::vector<RowPolicy::FullNameParts> & row_policies_names)
{
do
{
Strings policy_names;
if (!parseNames(pos, expected, policy_names))
return false;
String database, table_name;
if (!ParserKeyword{"ON"}.ignore(pos, expected) || !parseDatabaseAndTableName(pos, expected, database, table_name))
return false;
for (const String & policy_name : policy_names)
row_policies_names.push_back({database, table_name, policy_name});
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return true;
}
bool parseUserNames(IParserBase::Pos & pos, Expected & expected, Strings & names)
{
do
{
String name;
if (!parseUserName(pos, expected, name))
return false;
names.push_back(std::move(name));
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return true;
}
}
@ -37,6 +69,8 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
kind = Kind::QUOTA;
else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected))
kind = Kind::ROW_POLICY;
else if (ParserKeyword{"USER"}.ignore(pos, expected))
kind = Kind::USER;
else
return false;
@ -47,23 +81,19 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
Strings names;
std::vector<RowPolicy::FullNameParts> row_policies_names;
if (kind == Kind::ROW_POLICY)
if (kind == Kind::USER)
{
do
{
Strings policy_names;
if (!parseNames(pos, expected, policy_names))
return false;
String database, table_name;
if (!ParserKeyword{"ON"}.ignore(pos, expected) || !parseDatabaseAndTableName(pos, expected, database, table_name))
return false;
for (const String & policy_name : policy_names)
row_policies_names.push_back({database, table_name, policy_name});
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
if (!parseUserNames(pos, expected, names))
return false;
}
else if (kind == Kind::ROW_POLICY)
{
if (!parseRowPolicyNames(pos, expected, row_policies_names))
return false;
}
else
{
assert(kind == Kind::QUOTA);
if (!parseNames(pos, expected, names))
return false;
}

View File

@ -7,6 +7,8 @@ namespace DB
{
/** Parses queries like
* DROP QUOTA [IF EXISTS] name [,...]
* DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...]
* DROP USER [IF EXISTS] name [,...]
*/
class ParserDropAccessEntityQuery : public IParserBase
{

View File

@ -0,0 +1,232 @@
#include <Parsers/ParserGrantQuery.h>
#include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/ExpressionElementParsers.h>
#include <Parsers/ParserRoleList.h>
#include <boost/algorithm/string/predicate.hpp>
namespace DB
{
namespace
{
bool parseAccessFlags(IParser::Pos & pos, Expected & expected, AccessFlags & access_flags)
{
auto is_one_of_access_type_words = [](IParser::Pos & pos_)
{
if (pos_->type != TokenType::BareWord)
return false;
std::string_view word{pos_->begin, pos_->size()};
if (boost::iequals(word, "ON") || boost::iequals(word, "TO") || boost::iequals(word, "FROM"))
return false;
return true;
};
if (!is_one_of_access_type_words(pos))
{
expected.add(pos, "access type");
return false;
}
String str;
do
{
if (!str.empty())
str += " ";
std::string_view word{pos->begin, pos->size()};
str += std::string_view(pos->begin, pos->size());
++pos;
}
while (is_one_of_access_type_words(pos));
if (pos->type == TokenType::OpeningRoundBracket)
{
auto old_pos = pos;
++pos;
if (pos->type == TokenType::ClosingRoundBracket)
{
++pos;
str += "()";
}
else
pos = old_pos;
}
access_flags = AccessFlags{str};
return true;
}
bool parseColumnNames(IParser::Pos & pos, Expected & expected, Strings & columns)
{
if (!ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected))
return false;
do
{
ASTPtr column_ast;
if (!ParserIdentifier().parse(pos, column_ast, expected))
return false;
columns.push_back(getIdentifierName(column_ast));
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected);
}
bool parseDatabaseAndTableNameOrMaybeAsterisks(
IParser::Pos & pos, Expected & expected, String & database_name, bool & any_database, String & table_name, bool & any_table)
{
ASTPtr ast[2];
if (ParserToken{TokenType::Asterisk}.ignore(pos, expected))
{
if (ParserToken{TokenType::Dot}.ignore(pos, expected))
{
if (!ParserToken{TokenType::Asterisk}.ignore(pos, expected))
return false;
/// *.* (any table in any database)
any_database = true;
any_table = true;
return true;
}
else
{
/// * (any table in the current database)
any_database = false;
database_name = "";
any_table = true;
return true;
}
}
else if (ParserIdentifier().parse(pos, ast[0], expected))
{
if (ParserToken{TokenType::Dot}.ignore(pos, expected))
{
if (ParserToken{TokenType::Asterisk}.ignore(pos, expected))
{
/// <database_name>.*
any_database = false;
database_name = getIdentifierName(ast[0]);
any_table = true;
return true;
}
else if (ParserIdentifier().parse(pos, ast[1], expected))
{
/// <database_name>.<table_name>
any_database = false;
database_name = getIdentifierName(ast[0]);
any_table = false;
table_name = getIdentifierName(ast[1]);
return true;
}
else
return false;
}
else
{
/// <table_name> - the current database, specified table
any_database = false;
database_name = "";
table_name = getIdentifierName(ast[0]);
return true;
}
}
else
return false;
}
}
bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
using Kind = ASTGrantQuery::Kind;
Kind kind;
if (ParserKeyword{"GRANT"}.ignore(pos, expected))
kind = Kind::GRANT;
else if (ParserKeyword{"REVOKE"}.ignore(pos, expected))
kind = Kind::REVOKE;
else
return false;
bool grant_option = false;
if (kind == Kind::REVOKE)
{
if (ParserKeyword{"GRANT OPTION FOR"}.ignore(pos, expected))
grant_option = true;
}
AccessRightsElements elements;
do
{
std::vector<std::pair<AccessFlags, Strings>> access_and_columns;
do
{
AccessFlags access_flags;
if (!parseAccessFlags(pos, expected, access_flags))
return false;
Strings columns;
parseColumnNames(pos, expected, columns);
access_and_columns.emplace_back(access_flags, std::move(columns));
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
if (!ParserKeyword{"ON"}.ignore(pos, expected))
return false;
String database_name, table_name;
bool any_database = false, any_table = false;
if (!parseDatabaseAndTableNameOrMaybeAsterisks(pos, expected, database_name, any_database, table_name, any_table))
return false;
for (auto & [access_flags, columns] : access_and_columns)
{
AccessRightsElement element;
element.access_flags = access_flags;
element.any_column = columns.empty();
element.columns = std::move(columns);
element.any_database = any_database;
element.database = database_name;
element.any_table = any_table;
element.table = table_name;
elements.emplace_back(std::move(element));
}
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
ASTPtr to_roles;
if (kind == Kind::GRANT)
{
if (!ParserKeyword{"TO"}.ignore(pos, expected))
return false;
}
else
{
if (!ParserKeyword{"FROM"}.ignore(pos, expected))
return false;
}
if (!ParserRoleList{}.parse(pos, to_roles, expected))
return false;
if (kind == Kind::GRANT)
{
if (ParserKeyword{"WITH GRANT OPTION"}.ignore(pos, expected))
grant_option = true;
}
auto query = std::make_shared<ASTGrantQuery>();
node = query;
query->kind = kind;
query->access_rights_elements = std::move(elements);
query->to_roles = std::static_pointer_cast<ASTRoleList>(to_roles);
query->grant_option = grant_option;
return true;
}
}

View File

@ -0,0 +1,18 @@
#pragma once
#include <Parsers/IParserBase.h>
namespace DB
{
/** Parses queries like
* GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name
* REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name
*/
class ParserGrantQuery : public IParserBase
{
protected:
const char * getName() const override { return "GRANT or REVOKE query"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
}

View File

@ -9,9 +9,11 @@
#include <Parsers/ParserSetQuery.h>
#include <Parsers/ParserAlterQuery.h>
#include <Parsers/ParserSystemQuery.h>
#include <Parsers/ParserCreateUserQuery.h>
#include <Parsers/ParserCreateQuotaQuery.h>
#include <Parsers/ParserCreateRowPolicyQuery.h>
#include <Parsers/ParserDropAccessEntityQuery.h>
#include <Parsers/ParserGrantQuery.h>
namespace DB
@ -25,18 +27,22 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
ParserUseQuery use_p;
ParserSetQuery set_p;
ParserSystemQuery system_p;
ParserCreateUserQuery create_user_p;
ParserCreateQuotaQuery create_quota_p;
ParserCreateRowPolicyQuery create_row_policy_p;
ParserDropAccessEntityQuery drop_access_entity_p;
ParserGrantQuery grant_p;
bool res = query_with_output_p.parse(pos, node, expected)
|| insert_p.parse(pos, node, expected)
|| use_p.parse(pos, node, expected)
|| set_p.parse(pos, node, expected)
|| system_p.parse(pos, node, expected)
|| create_user_p.parse(pos, node, expected)
|| create_quota_p.parse(pos, node, expected)
|| create_row_policy_p.parse(pos, node, expected)
|| drop_access_entity_p.parse(pos, node, expected);
|| drop_access_entity_p.parse(pos, node, expected)
|| grant_p.parse(pos, node, expected);
return res;
}

View File

@ -14,6 +14,7 @@
#include <Parsers/ParserWatchQuery.h>
#include <Parsers/ParserSetQuery.h>
#include <Parsers/ASTExplainQuery.h>
#include <Parsers/ParserShowGrantsQuery.h>
#include <Parsers/ParserShowCreateAccessEntityQuery.h>
#include <Parsers/ParserShowQuotasQuery.h>
#include <Parsers/ParserShowRowPoliciesQuery.h>
@ -38,6 +39,7 @@ bool ParserQueryWithOutput::parseImpl(Pos & pos, ASTPtr & node, Expected & expec
ParserKillQueryQuery kill_query_p;
ParserWatchQuery watch_p;
ParserShowCreateAccessEntityQuery show_create_access_entity_p;
ParserShowGrantsQuery show_grants_p;
ParserShowQuotasQuery show_quotas_p;
ParserShowRowPoliciesQuery show_row_policies_p;
@ -68,6 +70,7 @@ bool ParserQueryWithOutput::parseImpl(Pos & pos, ASTPtr & node, Expected & expec
|| kill_query_p.parse(pos, query, expected)
|| optimize_p.parse(pos, query, expected)
|| watch_p.parse(pos, query, expected)
|| show_grants_p.parse(pos, query, expected)
|| show_quotas_p.parse(pos, query, expected)
|| show_row_policies_p.parse(pos, query, expected);

View File

@ -1,7 +1,7 @@
#include <Parsers/ParserRoleList.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/parseUserName.h>
#include <boost/range/algorithm/find.hpp>
@ -47,7 +47,7 @@ bool ParserRoleList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
else
{
String name;
if (!parseIdentifierOrStringLiteral(pos, expected, name))
if (!parseUserName(pos, expected, name))
return false;
if (except_mode && (boost::range::find(roles, name) == roles.end()))
except_roles.push_back(name);

View File

@ -3,6 +3,7 @@
#include <Parsers/CommonParsers.h>
#include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/parseDatabaseAndTableName.h>
#include <Parsers/parseUserName.h>
#include <assert.h>
@ -15,7 +16,9 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe
using Kind = ASTShowCreateAccessEntityQuery::Kind;
Kind kind;
if (ParserKeyword{"QUOTA"}.ignore(pos, expected))
if (ParserKeyword{"USER"}.ignore(pos, expected))
kind = Kind::USER;
else if (ParserKeyword{"QUOTA"}.ignore(pos, expected))
kind = Kind::QUOTA;
else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected))
kind = Kind::ROW_POLICY;
@ -24,9 +27,15 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe
String name;
bool current_quota = false;
bool current_user = false;
RowPolicy::FullNameParts row_policy_name;
if (kind == Kind::ROW_POLICY)
if (kind == Kind::USER)
{
if (!parseUserNameOrCurrentUserTag(pos, expected, name, current_user))
current_user = true;
}
else if (kind == Kind::ROW_POLICY)
{
String & database = row_policy_name.database;
String & table_name = row_policy_name.table_name;

View File

@ -0,0 +1,33 @@
#include <Parsers/ParserShowGrantsQuery.h>
#include <Parsers/ASTShowGrantsQuery.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/parseUserName.h>
namespace DB
{
bool ParserShowGrantsQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
if (!ParserKeyword{"SHOW GRANTS"}.ignore(pos, expected))
return false;
String name;
bool current_user = false;
if (ParserKeyword{"FOR"}.ignore(pos, expected))
{
if (!parseUserNameOrCurrentUserTag(pos, expected, name, current_user))
return false;
}
else
current_user = true;
auto query = std::make_shared<ASTShowGrantsQuery>();
node = query;
query->name = name;
query->current_user = current_user;
return true;
}
}

View File

@ -0,0 +1,17 @@
#pragma once
#include <Parsers/IParserBase.h>
namespace DB
{
/** Parses queries like
* SHOW GRANTS [FOR user_name]
*/
class ParserShowGrantsQuery : public IParserBase
{
protected:
const char * getName() const override { return "SHOW GRANTS query"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
}

View File

@ -0,0 +1,73 @@
#include <Parsers/parseUserName.h>
#include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/CommonParsers.h>
#include <boost/algorithm/string.hpp>
namespace DB
{
namespace
{
bool parseUserNameImpl(IParser::Pos & pos, Expected & expected, String & user_name, String * host_like_pattern)
{
String name;
if (!parseIdentifierOrStringLiteral(pos, expected, name))
return false;
boost::algorithm::trim(name);
String pattern = "@";
if (ParserToken{TokenType::At}.ignore(pos, expected))
{
if (!parseIdentifierOrStringLiteral(pos, expected, pattern))
return false;
boost::algorithm::trim(pattern);
}
if (pattern != "@")
name += '@' + pattern;
user_name = std::move(name);
if (host_like_pattern)
*host_like_pattern = std::move(pattern);
return true;
}
}
bool parseUserName(IParser::Pos & pos, Expected & expected, String & result)
{
return parseUserNameImpl(pos, expected, result, nullptr);
}
bool parseUserName(IParser::Pos & pos, Expected & expected, String & result, String & host_like_pattern)
{
return parseUserNameImpl(pos, expected, result, &host_like_pattern);
}
bool parseUserNameOrCurrentUserTag(IParser::Pos & pos, Expected & expected, String & user_name, bool & current_user)
{
if (ParserKeyword{"CURRENT_USER"}.ignore(pos, expected) || ParserKeyword{"currentUser"}.ignore(pos, expected))
{
if (ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected))
{
if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected))
return false;
}
current_user = true;
return true;
}
if (parseUserName(pos, expected, user_name))
{
current_user = false;
return true;
}
return false;
}
}

View File

@ -0,0 +1,25 @@
#pragma once
#include <Parsers/IParser.h>
namespace DB
{
/// Parses a user name. It can be a simple string or identifier or something like `name@host`.
/// In the last case `host` specifies the hosts user is allowed to connect from.
/// The `host` can be an ip address, ip subnet, or a host name.
/// The % and _ wildcard characters are permitted in `host`.
/// These have the same meaning as for pattern-matching operations performed with the LIKE operator.
bool parseUserName(IParser::Pos & pos, Expected & expected, String & user_name, String & host_like_pattern);
bool parseUserName(IParser::Pos & pos, Expected & expected, String & user_name);
/// Parses either a user name or the 'CURRENT_USER' keyword (or some of the aliases).
bool parseUserNameOrCurrentUserTag(IParser::Pos & pos, Expected & expected, String & user_name, bool & current_user);
/// Parses a role name. It follows the same rules as a user name, but allowed hosts are never checked
/// (because roles are not used to connect to server).
inline bool parseRoleName(IParser::Pos & pos, Expected & expected, String & role_name)
{
return parseUserName(pos, expected, role_name);
}
}

View File

@ -55,6 +55,18 @@ public:
return remove(MergeTreePartInfo::fromPartName(part_name, format_version));
}
/// Remove part and all covered parts from active set
bool removePartAndCoveredParts(const String & part_name)
{
Strings parts_covered_by = getPartsCoveredBy(MergeTreePartInfo::fromPartName(part_name, format_version));
bool result = true;
result &= remove(part_name);
for (const auto & part : parts_covered_by)
result &= remove(part);
return result;
}
/// If not found, return an empty string.
String getContainingPart(const MergeTreePartInfo & part_info) const;
String getContainingPart(const String & name) const;
@ -66,6 +78,11 @@ public:
size_t size() const;
void clear()
{
part_info_to_name.clear();
}
MergeTreeDataFormatVersion getFormatVersion() const { return format_version; }
private:

View File

@ -431,9 +431,11 @@ void MergeTreeRangeReader::ReadResult::setFilter(const ColumnPtr & new_filter)
}
ConstantFilterDescription const_description(*new_filter);
if (const_description.always_false)
if (const_description.always_true)
setFilterConstTrue();
else if (const_description.always_false)
clear();
else if (!const_description.always_true)
else
{
FilterDescription filter_description(*new_filter);
filter_holder = filter_description.data_holder ? filter_description.data_holder : new_filter;

View File

@ -130,7 +130,7 @@ void ReplicatedMergeTreeQueue::insertUnlocked(
for (const String & virtual_part_name : entry->getVirtualPartNames())
{
virtual_parts.add(virtual_part_name);
updateMutationsPartsToDo(virtual_part_name, /* add = */ true);
addPartToMutations(virtual_part_name);
}
/// Put 'DROP PARTITION' entries at the beginning of the queue not to make superfluous fetches of parts that will be eventually deleted
@ -200,12 +200,16 @@ void ReplicatedMergeTreeQueue::updateStateOnQueueEntryRemoval(
for (const String & virtual_part_name : entry->getVirtualPartNames())
{
Strings replaced_parts;
/// In most cases we will replace only current parts, but sometimes
/// we can even replace virtual parts. For example when we failed to
/// GET source part and dowloaded merged/mutated part instead.
current_parts.add(virtual_part_name, &replaced_parts);
virtual_parts.add(virtual_part_name, &replaced_parts);
/// Each part from `replaced_parts` should become Obsolete as a result of executing the entry.
/// So it is one less part to mutate for each mutation with block number greater than part_info.getDataVersion()
/// So it is one less part to mutate for each mutation with block number greater or equal than part_info.getDataVersion()
for (const String & replaced_part_name : replaced_parts)
updateMutationsPartsToDo(replaced_part_name, /* add = */ false);
removePartFromMutations(replaced_part_name);
}
String drop_range_part_name;
@ -226,13 +230,13 @@ void ReplicatedMergeTreeQueue::updateStateOnQueueEntryRemoval(
{
/// Because execution of the entry is unsuccessful, `virtual_part_name` will never appear
/// so we won't need to mutate it.
updateMutationsPartsToDo(virtual_part_name, /* add = */ false);
removePartFromMutations(virtual_part_name);
}
}
}
void ReplicatedMergeTreeQueue::updateMutationsPartsToDo(const String & part_name, bool add)
void ReplicatedMergeTreeQueue::removePartFromMutations(const String & part_name)
{
auto part_info = MergeTreePartInfo::fromPartName(part_name, format_version);
auto in_partition = mutations_by_partition.find(part_info.partition_id);
@ -241,15 +245,15 @@ void ReplicatedMergeTreeQueue::updateMutationsPartsToDo(const String & part_name
bool some_mutations_are_probably_done = false;
auto from_it = in_partition->second.upper_bound(part_info.getDataVersion());
auto from_it = in_partition->second.lower_bound(part_info.getDataVersion());
for (auto it = from_it; it != in_partition->second.end(); ++it)
{
MutationStatus & status = *it->second;
status.parts_to_do += (add ? +1 : -1);
if (status.parts_to_do <= 0)
status.parts_to_do.removePartAndCoveredParts(part_name);
if (status.parts_to_do.size() == 0)
some_mutations_are_probably_done = true;
if (!add && !status.latest_failed_part.empty() && part_info.contains(status.latest_failed_part_info))
if (!status.latest_failed_part.empty() && part_info.contains(status.latest_failed_part_info))
{
status.latest_failed_part.clear();
status.latest_failed_part_info = MergeTreePartInfo();
@ -262,6 +266,20 @@ void ReplicatedMergeTreeQueue::updateMutationsPartsToDo(const String & part_name
storage.mutations_finalizing_task->schedule();
}
void ReplicatedMergeTreeQueue::addPartToMutations(const String & part_name)
{
auto part_info = MergeTreePartInfo::fromPartName(part_name, format_version);
auto in_partition = mutations_by_partition.find(part_info.partition_id);
if (in_partition == mutations_by_partition.end())
return;
auto from_it = in_partition->second.upper_bound(part_info.getDataVersion());
for (auto it = from_it; it != in_partition->second.end(); ++it)
{
MutationStatus & status = *it->second;
status.parts_to_do.add(part_name);
}
}
void ReplicatedMergeTreeQueue::updateTimesInZooKeeper(
zkutil::ZooKeeperPtr zookeeper,
@ -629,7 +647,7 @@ void ReplicatedMergeTreeQueue::updateMutations(zkutil::ZooKeeperPtr zookeeper, C
for (const ReplicatedMergeTreeMutationEntryPtr & entry : new_mutations)
{
auto & mutation = mutations_by_znode.emplace(entry->znode_name, MutationStatus(entry))
auto & mutation = mutations_by_znode.emplace(entry->znode_name, MutationStatus(entry, format_version))
.first->second;
for (const auto & pair : entry->block_numbers)
@ -640,7 +658,9 @@ void ReplicatedMergeTreeQueue::updateMutations(zkutil::ZooKeeperPtr zookeeper, C
}
/// Initialize `mutation.parts_to_do`. First we need to mutate all parts in `current_parts`.
mutation.parts_to_do += getPartNamesToMutate(*entry, current_parts).size();
Strings current_parts_to_mutate = getPartNamesToMutate(*entry, current_parts);
for (const String & current_part_to_mutate : current_parts_to_mutate)
mutation.parts_to_do.add(current_part_to_mutate);
/// And next we would need to mutate all parts with getDataVersion() greater than
/// mutation block number that would appear as a result of executing the queue.
@ -651,11 +671,11 @@ void ReplicatedMergeTreeQueue::updateMutations(zkutil::ZooKeeperPtr zookeeper, C
auto part_info = MergeTreePartInfo::fromPartName(produced_part_name, format_version);
auto it = entry->block_numbers.find(part_info.partition_id);
if (it != entry->block_numbers.end() && it->second > part_info.getDataVersion())
++mutation.parts_to_do;
mutation.parts_to_do.add(produced_part_name);
}
}
if (mutation.parts_to_do == 0)
if (mutation.parts_to_do.size() == 0)
some_mutations_are_probably_done = true;
}
}
@ -1277,8 +1297,14 @@ bool ReplicatedMergeTreeQueue::tryFinalizeMutations(zkutil::ZooKeeperPtr zookeep
{
LOG_TRACE(log, "Marking mutation " << znode << " done because it is <= mutation_pointer (" << mutation_pointer << ")");
mutation.is_done = true;
if (mutation.parts_to_do.size() != 0)
{
LOG_INFO(log, "Seems like we jumped over mutation " << znode << " when downloaded part with bigger mutation number."
<< " It's OK, tasks for rest parts will be skipped, but probably a lot of mutations were executed concurrently on different replicas.");
mutation.parts_to_do.clear();
}
}
else if (mutation.parts_to_do == 0)
else if (mutation.parts_to_do.size() == 0)
{
LOG_TRACE(log, "Will check if mutation " << mutation.entry->znode_name << " is done");
candidates.push_back(mutation.entry);
@ -1417,7 +1443,7 @@ std::vector<MergeTreeMutationStatus> ReplicatedMergeTreeQueue::getMutationsStatu
{
const MutationStatus & status = pair.second;
const ReplicatedMergeTreeMutationEntry & entry = *status.entry;
const Names parts_to_mutate = getPartNamesToMutate(entry, current_parts);
Names parts_to_mutate = status.parts_to_do.getParts();
for (const MutationCommand & command : entry.commands)
{

View File

@ -97,18 +97,21 @@ private:
struct MutationStatus
{
MutationStatus(const ReplicatedMergeTreeMutationEntryPtr & entry_)
MutationStatus(const ReplicatedMergeTreeMutationEntryPtr & entry_, MergeTreeDataFormatVersion format_version_)
: entry(entry_)
, parts_to_do(format_version_)
{
}
ReplicatedMergeTreeMutationEntryPtr entry;
/// A number of parts that should be mutated/merged or otherwise moved to Obsolete state for this mutation to complete.
Int64 parts_to_do = 0;
/// Parts we have to mutate to complete mutation. We use ActiveDataPartSet structure
/// to be able to manage covering and covered parts.
ActiveDataPartSet parts_to_do;
/// Note that is_done is not equivalent to parts_to_do == 0
/// (even if parts_to_do == 0 some relevant parts can still commit in the future).
/// Note that is_done is not equivalent to parts_to_do.size() == 0
/// (even if parts_to_do.size() == 0 some relevant parts can still commit in the future).
/// Also we can jump over mutation when we dowload mutated part from other replica.
bool is_done = false;
String latest_failed_part;
@ -191,9 +194,14 @@ private:
std::optional<time_t> & max_processed_insert_time_changed,
std::unique_lock<std::mutex> & state_lock);
/// If the new part appears (add == true) or becomes obsolete (add == false), update parts_to_do of all affected mutations.
/// Notifies storage.mutations_finalizing_task if some mutations are probably finished.
void updateMutationsPartsToDo(const String & part_name, bool add);
/// Add part for mutations with block_number > part.getDataVersion()
void addPartToMutations(const String & part_name);
/// Remove part from mutations which were assigned to mutate it
/// with block_number > part.getDataVersion()
/// and block_number == part.getDataVersion()
/// ^ (this may happen if we downloaded mutated part from other replica)
void removePartFromMutations(const String & part_name);
/// Update the insertion times in ZooKeeper.
void updateTimesInZooKeeper(zkutil::ZooKeeperPtr zookeeper,

View File

@ -382,4 +382,133 @@
</attribute>
</structure>
</dictionary>
</dictionaries>
<dictionary>
<name>one_cell_cache_ints</name>
<source>
<clickhouse>
<host>localhost</host>
<port>9000</port>
<user>default</user>
<password></password>
<db>test_01054</db>
<table>ints</table>
</clickhouse>
</source>
<lifetime>0</lifetime>
<layout>
<cache><size_in_cells>1</size_in_cells></cache>
</layout>
<structure>
<id>
<name>key</name>
</id>
<attribute>
<name>i8</name>
<type>Int8</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>i16</name>
<type>Int16</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>i32</name>
<type>Int32</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>i64</name>
<type>Int64</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u8</name>
<type>UInt8</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u16</name>
<type>UInt16</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u32</name>
<type>UInt32</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u64</name>
<type>UInt64</type>
<null_value>0</null_value>
</attribute>
</structure>
</dictionary>
<dictionary>
<name>one_cell_cache_ints_overflow</name>
<source>
<clickhouse>
<host>localhost</host>
<port>9000</port>
<user>default</user>
<password></password>
<db>test_01054_overflow</db>
<table>ints</table>
</clickhouse>
</source>
<lifetime>0</lifetime>
<layout>
<cache><size_in_cells>1</size_in_cells></cache>
</layout>
<structure>
<id>
<name>key</name>
</id>
<attribute>
<name>i8</name>
<type>Int8</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>i16</name>
<type>Int16</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>i32</name>
<type>Int32</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>i64</name>
<type>Int64</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u8</name>
<type>UInt8</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u16</name>
<type>UInt16</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u32</name>
<type>UInt32</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u64</name>
<type>UInt64</type>
<null_value>0</null_value>
</attribute>
</structure>
</dictionary>
</dictionaries>

View File

@ -643,6 +643,22 @@ class ClickHouseInstance:
return urllib.urlopen(url, data).read()
def kill_clickhouse(self, stop_start_wait_sec=5):
pid = self.get_process_pid("clickhouse")
if not pid:
raise Exception("No clickhouse found")
self.exec_in_container(["bash", "-c", "kill -9 {}".format(pid)], user='root')
time.sleep(stop_start_wait_sec)
def restore_clickhouse(self, retries=100):
pid = self.get_process_pid("clickhouse")
if pid:
raise Exception("ClickHouse has already started")
self.exec_in_container(["bash", "-c", "{} --daemon".format(CLICKHOUSE_START_COMMAND)], user=str(os.getuid()))
from helpers.test_tools import assert_eq_with_retry
# wait start
assert_eq_with_retry(self, "select 1", "1", retry_count=retries)
def restart_clickhouse(self, stop_start_wait_sec=5, kill=False):
if not self.stay_alive:
raise Exception("clickhouse can be restarted only with stay_alive=True instance")
@ -949,3 +965,14 @@ class ClickHouseInstance:
def destroy_dir(self):
if p.exists(self.path):
shutil.rmtree(self.path)
class ClickHouseKiller(object):
def __init__(self, clickhouse_node):
self.clickhouse_node = clickhouse_node
def __enter__(self):
self.clickhouse_node.kill_clickhouse()
def __exit__(self, exc_type, exc_val, exc_tb):
self.clickhouse_node.restore_clickhouse()

View File

@ -90,7 +90,7 @@ class PartitionManager:
self.heal_all()
class PartitionManagerDisbaler:
class PartitionManagerDisabler:
def __init__(self, manager):
self.manager = manager
self.rules = self.manager.pop_rules()

View File

@ -293,13 +293,16 @@ class DictionaryStructure(object):
class Dictionary(object):
def __init__(self, name, structure, source, config_path, table_name, fields):
def __init__(self, name, structure, source, config_path,
table_name, fields, min_lifetime=3, max_lifetime=5):
self.name = name
self.structure = copy.deepcopy(structure)
self.source = copy.deepcopy(source)
self.config_path = config_path
self.table_name = table_name
self.fields = fields
self.min_lifetime = min_lifetime
self.max_lifetime = max_lifetime
def generate_config(self):
with open(self.config_path, 'w') as result:
@ -307,8 +310,8 @@ class Dictionary(object):
<yandex>
<dictionary>
<lifetime>
<min>3</min>
<max>5</max>
<min>{min_lifetime}</min>
<max>{max_lifetime}</max>
</lifetime>
<name>{name}</name>
{structure}
@ -318,6 +321,8 @@ class Dictionary(object):
</dictionary>
</yandex>
'''.format(
min_lifetime=self.min_lifetime,
max_lifetime=self.max_lifetime,
name=self.name,
structure=self.structure.get_structure_str(),
source=self.source.get_source_str(self.table_name),

View File

@ -72,34 +72,34 @@ FIELDS = {
VALUES = {
"simple": [
[1, 22, 333, 4444, 55555, -6, -77,
-888, -999, '550e8400-e29b-41d4-a716-446655440003',
'1973-06-28', '1985-02-28 23:43:25', 'hello', 22.543, 3332154213.4, 0],
-888, -999, '550e8400-e29b-41d4-a716-446655440003',
'1973-06-28', '1985-02-28 23:43:25', 'hello', 22.543, 3332154213.4, 0],
[2, 3, 4, 5, 6, -7, -8,
-9, -10, '550e8400-e29b-41d4-a716-446655440002',
'1978-06-28', '1986-02-28 23:42:25', 'hello', 21.543, 3222154213.4, 1]
-9, -10, '550e8400-e29b-41d4-a716-446655440002',
'1978-06-28', '1986-02-28 23:42:25', 'hello', 21.543, 3222154213.4, 1]
],
"complex": [
[1, 'world', 22, 333, 4444, 55555, -6,
-77, -888, -999, '550e8400-e29b-41d4-a716-446655440003',
'1973-06-28', '1985-02-28 23:43:25',
'hello', 22.543, 3332154213.4],
-77, -888, -999, '550e8400-e29b-41d4-a716-446655440003',
'1973-06-28', '1985-02-28 23:43:25',
'hello', 22.543, 3332154213.4],
[2, 'qwerty2', 52, 2345, 6544, 9191991, -2,
-717, -81818, -92929, '550e8400-e29b-41d4-a716-446655440007',
'1975-09-28', '2000-02-28 23:33:24',
'my', 255.543, 3332221.44]
-717, -81818, -92929, '550e8400-e29b-41d4-a716-446655440007',
'1975-09-28', '2000-02-28 23:33:24',
'my', 255.543, 3332221.44]
],
"ranged": [
[1, '2019-02-10', '2019-02-01', '2019-02-28',
22, 333, 4444, 55555, -6, -77, -888, -999,
'550e8400-e29b-41d4-a716-446655440003',
'1973-06-28', '1985-02-28 23:43:25', 'hello',
22.543, 3332154213.4],
22, 333, 4444, 55555, -6, -77, -888, -999,
'550e8400-e29b-41d4-a716-446655440003',
'1973-06-28', '1985-02-28 23:43:25', 'hello',
22.543, 3332154213.4],
[2, '2019-04-10', '2019-04-01', '2019-04-28',
11, 3223, 41444, 52515, -65, -747, -8388, -9099,
'550e8400-e29b-41d4-a716-446655440004',
'1973-06-29', '2002-02-28 23:23:25', '!!!!',
32.543, 3332543.4]
11, 3223, 41444, 52515, -65, -747, -8388, -9099,
'550e8400-e29b-41d4-a716-446655440004',
'1973-06-29', '2002-02-28 23:23:25', '!!!!',
32.543, 3332543.4]
]
}

View File

@ -0,0 +1,30 @@
<?xml version="1.0"?>
<yandex>
<logger>
<level>trace</level>
<log>/var/log/clickhouse-server/clickhouse-server.log</log>
<errorlog>/var/log/clickhouse-server/clickhouse-server.err.log</errorlog>
<size>1000M</size>
<count>10</count>
</logger>
<tcp_port>9000</tcp_port>
<listen_host>127.0.0.1</listen_host>
<openSSL>
<client>
<cacheSessions>true</cacheSessions>
<verificationMode>none</verificationMode>
<invalidCertificateHandler>
<name>AcceptCertificateHandler</name>
</invalidCertificateHandler>
</client>
</openSSL>
<max_concurrent_queries>500</max_concurrent_queries>
<mark_cache_size>5368709120</mark_cache_size>
<path>./clickhouse/</path>
<users_config>users.xml</users_config>
<dictionaries_config>/etc/clickhouse-server/config.d/*.xml</dictionaries_config>
</yandex>

View File

@ -0,0 +1,72 @@
<yandex>
<dictionary>
<name>anime_dict</name>
<source>
<clickhouse>
<host>dictionary_node</host>
<port>9000</port>
<user>default</user>
<password></password>
<db>test</db>
<table>ints</table>
</clickhouse>
</source>
<lifetime>
<max>2</max>
<min>1</min>
</lifetime>
<layout>
<cache>
<size_in_cells>10000</size_in_cells>
<max_update_queue_size>10000</max_update_queue_size>
<allow_read_expired_keys>1</allow_read_expired_keys>
<update_queue_push_timeout_milliseconds>10</update_queue_push_timeout_milliseconds>
</cache>
</layout>
<structure>
<id>
<name>key</name>
</id>
<attribute>
<name>i8</name>
<type>Int8</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>i16</name>
<type>Int16</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>i32</name>
<type>Int32</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>i64</name>
<type>Int64</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u8</name>
<type>UInt8</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u16</name>
<type>UInt16</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u32</name>
<type>UInt32</type>
<null_value>0</null_value>
</attribute>
<attribute>
<name>u64</name>
<type>UInt64</type>
<null_value>0</null_value>
</attribute>
</structure>
</dictionary>
</yandex>

View File

@ -0,0 +1,23 @@
<?xml version="1.0"?>
<yandex>
<profiles>
<default>
</default>
</profiles>
<users>
<default>
<password></password>
<networks incl="networks" replace="replace">
<ip>::/0</ip>
</networks>
<profile>default</profile>
<quota>default</quota>
</default>
</users>
<quotas>
<default>
</default>
</quotas>
</yandex>

View File

@ -0,0 +1,64 @@
from __future__ import print_function
import pytest
import time
import os
from contextlib import contextmanager
from helpers.cluster import ClickHouseCluster
from helpers.cluster import ClickHouseKiller
from helpers.network import PartitionManager
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
cluster = ClickHouseCluster(__file__, base_configs_dir=os.path.join(SCRIPT_DIR, 'configs'))
dictionary_node = cluster.add_instance('dictionary_node', stay_alive=True)
main_node = cluster.add_instance('main_node', main_configs=['configs/dictionaries/cache_ints_dictionary.xml'])
@pytest.fixture(scope="module")
def started_cluster():
try:
cluster.start()
dictionary_node.query("create database if not exists test;")
dictionary_node.query("drop table if exists test.ints;")
dictionary_node.query("create table test.ints "
"(key UInt64, "
"i8 Int8, i16 Int16, i32 Int32, i64 Int64, "
"u8 UInt8, u16 UInt16, u32 UInt32, u64 UInt64) "
"Engine = Memory;")
dictionary_node.query("insert into test.ints values (7, 7, 7, 7, 7, 7, 7, 7, 7);")
dictionary_node.query("insert into test.ints values (5, 5, 5, 5, 5, 5, 5, 5, 5);")
yield cluster
finally:
cluster.shutdown()
# @pytest.mark.skip(reason="debugging")
def test_default_reading(started_cluster):
assert None != dictionary_node.get_process_pid("clickhouse"), "ClickHouse must be alive"
# Key 0 is not in dictionary, so default value will be returned
def test_helper():
assert '42' == main_node.query("select dictGetOrDefault('anime_dict', 'i8', toUInt64(13), toInt8(42));").rstrip()
assert '42' == main_node.query("select dictGetOrDefault('anime_dict', 'i16', toUInt64(13), toInt16(42));").rstrip()
assert '42' == main_node.query("select dictGetOrDefault('anime_dict', 'i32', toUInt64(13), toInt32(42));").rstrip()
assert '42' == main_node.query("select dictGetOrDefault('anime_dict', 'i64', toUInt64(13), toInt64(42));").rstrip()
assert '42' == main_node.query("select dictGetOrDefault('anime_dict', 'u8', toUInt64(13), toUInt8(42));").rstrip()
assert '42' == main_node.query("select dictGetOrDefault('anime_dict', 'u16', toUInt64(13), toUInt16(42));").rstrip()
assert '42' == main_node.query("select dictGetOrDefault('anime_dict', 'u32', toUInt64(13), toUInt32(42));").rstrip()
assert '42' == main_node.query("select dictGetOrDefault('anime_dict', 'u64', toUInt64(13), toUInt64(42));").rstrip()
test_helper()
with PartitionManager() as pm, ClickHouseKiller(dictionary_node):
assert None == dictionary_node.get_process_pid("clickhouse"), "CLickHouse must be alive"
# Remove connection between main_node and dictionary for sure
pm.heal_all()
pm.partition_instances(main_node, dictionary_node)
# Dictionary max lifetime is 2 seconds.
time.sleep(3)
test_helper()

View File

@ -0,0 +1,63 @@
from __future__ import print_function
import pytest
import time
import os
from contextlib import contextmanager
from helpers.cluster import ClickHouseCluster
from helpers.cluster import ClickHouseKiller
from helpers.network import PartitionManager
from helpers.network import PartitionManagerDisabler
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
cluster = ClickHouseCluster(__file__, base_configs_dir=os.path.join(SCRIPT_DIR, 'configs'))
dictionary_node = cluster.add_instance('dictionary_node', stay_alive=True)
main_node = cluster.add_instance('main_node', main_configs=['configs/dictionaries/cache_ints_dictionary.xml'])
@pytest.fixture(scope="module")
def started_cluster():
try:
cluster.start()
dictionary_node.query("create database if not exists test;")
dictionary_node.query("drop table if exists test.ints;")
dictionary_node.query("create table test.ints "
"(key UInt64, "
"i8 Int8, i16 Int16, i32 Int32, i64 Int64, "
"u8 UInt8, u16 UInt16, u32 UInt32, u64 UInt64) "
"Engine = Memory;")
dictionary_node.query("insert into test.ints values (7, 7, 7, 7, 7, 7, 7, 7, 7);")
dictionary_node.query("insert into test.ints values (5, 5, 5, 5, 5, 5, 5, 5, 5);")
yield cluster
finally:
cluster.shutdown()
# @pytest.mark.skip(reason="debugging")
def test_simple_dict_get(started_cluster):
assert None != dictionary_node.get_process_pid("clickhouse"), "ClickHouse must be alive"
def test_helper():
assert '7' == main_node.query("select dictGet('anime_dict', 'i8', toUInt64(7));").rstrip(), "Wrong answer."
assert '7' == main_node.query("select dictGet('anime_dict', 'i16', toUInt64(7));").rstrip(), "Wrong answer."
assert '7' == main_node.query("select dictGet('anime_dict', 'i32', toUInt64(7));").rstrip(), "Wrong answer."
assert '7' == main_node.query("select dictGet('anime_dict', 'i64', toUInt64(7));").rstrip(), "Wrong answer."
assert '7' == main_node.query("select dictGet('anime_dict', 'u8', toUInt64(7));").rstrip(), "Wrong answer."
assert '7' == main_node.query("select dictGet('anime_dict', 'u16', toUInt64(7));").rstrip(), "Wrong answer."
assert '7' == main_node.query("select dictGet('anime_dict', 'u32', toUInt64(7));").rstrip(), "Wrong answer."
assert '7' == main_node.query("select dictGet('anime_dict', 'u64', toUInt64(7));").rstrip(), "Wrong answer."
test_helper()
with PartitionManager() as pm, ClickHouseKiller(dictionary_node):
assert None == dictionary_node.get_process_pid("clickhouse")
# Remove connection between main_node and dictionary for sure
pm.heal_all()
pm.partition_instances(main_node, dictionary_node)
# Dictionary max lifetime is 2 seconds.
time.sleep(3)
test_helper()

View File

@ -0,0 +1,60 @@
from __future__ import print_function
import pytest
import time
import os
from contextlib import contextmanager
from helpers.cluster import ClickHouseCluster
from helpers.cluster import ClickHouseKiller
from helpers.network import PartitionManager
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
cluster = ClickHouseCluster(__file__, base_configs_dir=os.path.join(SCRIPT_DIR, 'configs'))
dictionary_node = cluster.add_instance('dictionary_node', stay_alive=True)
main_node = cluster.add_instance('main_node', main_configs=['configs/dictionaries/cache_ints_dictionary.xml'])
@pytest.fixture(scope="module")
def started_cluster():
try:
cluster.start()
dictionary_node.query("create database if not exists test;")
dictionary_node.query("drop table if exists test.ints;")
dictionary_node.query("create table test.ints "
"(key UInt64, "
"i8 Int8, i16 Int16, i32 Int32, i64 Int64, "
"u8 UInt8, u16 UInt16, u32 UInt32, u64 UInt64) "
"Engine = Memory;")
dictionary_node.query("insert into test.ints values (7, 7, 7, 7, 7, 7, 7, 7, 7);")
dictionary_node.query("insert into test.ints values (5, 5, 5, 5, 5, 5, 5, 5, 5);")
yield cluster
finally:
cluster.shutdown()
# @pytest.mark.skip(reason="debugging")
def test_simple_dict_get_or_default(started_cluster):
assert None != dictionary_node.get_process_pid("clickhouse"), "ClickHouse must be alive"
def test_helper():
assert '5' == main_node.query("select dictGetOrDefault('anime_dict', 'i8', toUInt64(5), toInt8(42));").rstrip()
assert '5' == main_node.query("select dictGetOrDefault('anime_dict', 'i16', toUInt64(5), toInt16(42));").rstrip()
assert '5' == main_node.query("select dictGetOrDefault('anime_dict', 'i32', toUInt64(5), toInt32(42));").rstrip()
assert '5' == main_node.query("select dictGetOrDefault('anime_dict', 'i64', toUInt64(5), toInt64(42));").rstrip()
assert '5' == main_node.query("select dictGetOrDefault('anime_dict', 'u8', toUInt64(5), toUInt8(42));").rstrip()
assert '5' == main_node.query("select dictGetOrDefault('anime_dict', 'u16', toUInt64(5), toUInt16(42));").rstrip()
assert '5' == main_node.query("select dictGetOrDefault('anime_dict', 'u32', toUInt64(5), toUInt32(42));").rstrip()
assert '5' == main_node.query("select dictGetOrDefault('anime_dict', 'u64', toUInt64(5), toUInt64(42));").rstrip()
test_helper()
with PartitionManager() as pm, ClickHouseKiller(dictionary_node):
assert None == dictionary_node.get_process_pid("clickhouse")
# Remove connection between main_node and dictionary for sure
pm.partition_instances(main_node, dictionary_node)
# Dictionary max lifetime is 2 seconds.
time.sleep(3)
test_helper()

View File

@ -0,0 +1,16 @@
<?xml version="1.0"?>
<yandex>
<profiles>
<default>
</default>
</profiles>
<users>
<default>
<password></password>
<networks incl="networks" replace="replace">
<ip>::/0</ip>
</networks>
<profile>default</profile>
</default>
</users>
</yandex>

View File

@ -0,0 +1,49 @@
import pytest
from helpers.cluster import ClickHouseCluster
import re
cluster = ClickHouseCluster(__file__)
instance = cluster.add_instance('instance', config_dir="configs")
@pytest.fixture(scope="module", autouse=True)
def started_cluster():
try:
cluster.start()
instance.query("CREATE TABLE test_table(x UInt32) ENGINE = MergeTree ORDER BY tuple()")
instance.query("INSERT INTO test_table SELECT number FROM numbers(3)")
instance.query("CREATE USER A PROFILE 'default'")
instance.query("CREATE USER B PROFILE 'default'")
yield cluster
finally:
cluster.shutdown()
def test_login():
assert instance.query("SELECT 1", user='A') == "1\n"
assert instance.query("SELECT 1", user='B') == "1\n"
def test_grant_and_revoke():
assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A')
instance.query('GRANT SELECT ON test_table TO A')
assert instance.query("SELECT * FROM test_table", user='A') == "0\n1\n2\n"
instance.query('REVOKE SELECT ON test_table FROM A')
assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A')
def test_grant_option():
instance.query('GRANT SELECT ON test_table TO A')
assert instance.query("SELECT * FROM test_table", user='A') == "0\n1\n2\n"
assert "Not enough privileges" in instance.query_and_get_error("GRANT SELECT ON test_table TO B", user='A')
instance.query('GRANT SELECT ON test_table TO A WITH GRANT OPTION')
instance.query("GRANT SELECT ON test_table TO B", user='A')
assert instance.query("SELECT * FROM test_table", user='B') == "0\n1\n2\n"
instance.query('REVOKE SELECT ON test_table FROM A, B')

View File

@ -1,2 +1,6 @@
# https://github.com/google/oss-fuzz/issues/1099
fun:__gxx_personality_*
# We apply std::tolower to uninitialized padding, but don't use the result, so
# it is OK. Reproduce with "select ngramDistanceCaseInsensitive(materialize(''), '')"
fun:tolower

View File

@ -0,0 +1,74 @@
#!/usr/bin/env bash
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
. $CURDIR/../shell_config.sh
$CLICKHOUSE_CLIENT --query="create database if not exists test_01054;"
$CLICKHOUSE_CLIENT --query="drop table if exists test_01054.ints;"
$CLICKHOUSE_CLIENT --query="create table test_01054.ints
(key UInt64, i8 Int8, i16 Int16, i32 Int32, i64 Int64, u8 UInt8, u16 UInt16, u32 UInt32, u64 UInt64)
Engine = Memory;"
$CLICKHOUSE_CLIENT --query="insert into test_01054.ints values (1, 1, 1, 1, 1, 1, 1, 1, 1);"
$CLICKHOUSE_CLIENT --query="insert into test_01054.ints values (2, 2, 2, 2, 2, 2, 2, 2, 2);"
$CLICKHOUSE_CLIENT --query="insert into test_01054.ints values (3, 3, 3, 3, 3, 3, 3, 3, 3);"
function thread1()
{
for attempt_thread1 in {1..100}
do
RAND_NUMBER_THREAD1=$($CLICKHOUSE_CLIENT --query="SELECT rand() % 100;")
$CLICKHOUSE_CLIENT --query="select dictGet('one_cell_cache_ints', 'i8', toUInt64($RAND_NUMBER_THREAD1));"
done
}
function thread2()
{
for attempt_thread2 in {1..100}
do
RAND_NUMBER_THREAD2=$($CLICKHOUSE_CLIENT --query="SELECT rand() % 100;")
$CLICKHOUSE_CLIENT --query="select dictGet('one_cell_cache_ints', 'i8', toUInt64($RAND_NUMBER_THREAD2));"
done
}
function thread3()
{
for attempt_thread3 in {1..100}
do
RAND_NUMBER_THREAD3=$($CLICKHOUSE_CLIENT --query="SELECT rand() % 100;")
$CLICKHOUSE_CLIENT --query="select dictGet('one_cell_cache_ints', 'i8', toUInt64($RAND_NUMBER_THREAD3));"
done
}
function thread4()
{
for attempt_thread4 in {1..100}
do
RAND_NUMBER_THREAD4=$($CLICKHOUSE_CLIENT --query="SELECT rand() % 100;")
$CLICKHOUSE_CLIENT --query="select dictGet('one_cell_cache_ints', 'i8', toUInt64($RAND_NUMBER_THREAD4));"
done
}
export -f thread1;
export -f thread2;
export -f thread3;
export -f thread4;
TIMEOUT=10
# shellcheck disable=SC2188
timeout $TIMEOUT bash -c thread1 > /dev/null 2>&1 &
timeout $TIMEOUT bash -c thread2 > /dev/null 2>&1 &
timeout $TIMEOUT bash -c thread3 > /dev/null 2>&1 &
timeout $TIMEOUT bash -c thread4 > /dev/null 2>&1 &
wait
echo OK
$CLICKHOUSE_CLIENT --query "DROP TABLE if exists test_01054.ints"

View File

@ -0,0 +1,2 @@
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]

View File

@ -0,0 +1,56 @@
create database if not exists test_01054_overflow;
drop table if exists test_01054_overflow.ints;
create table test_01054_overflow.ints (key UInt64, i8 Int8, i16 Int16, i32 Int32, i64 Int64, u8 UInt8, u16 UInt16, u32 UInt32, u64 UInt64) Engine = Memory;
insert into test_01054_overflow.ints values (1, 1, 1, 1, 1, 1, 1, 1, 1);
insert into test_01054_overflow.ints values (2, 2, 2, 2, 2, 2, 2, 2, 2);
insert into test_01054_overflow.ints values (3, 3, 3, 3, 3, 3, 3, 3, 3);
insert into test_01054_overflow.ints values (4, 4, 4, 4, 4, 4, 4, 4, 4);
insert into test_01054_overflow.ints values (5, 5, 5, 5, 5, 5, 5, 5, 5);
insert into test_01054_overflow.ints values (6, 6, 6, 6, 6, 6, 6, 6, 6);
insert into test_01054_overflow.ints values (7, 7, 7, 7, 7, 7, 7, 7, 7);
insert into test_01054_overflow.ints values (8, 8, 8, 8, 8, 8, 8, 8, 8);
insert into test_01054_overflow.ints values (9, 9, 9, 9, 9, 9, 9, 9, 9);
insert into test_01054_overflow.ints values (10, 10, 10, 10, 10, 10, 10, 10, 10);
insert into test_01054_overflow.ints values (11, 11, 11, 11, 11, 11, 11, 11, 11);
insert into test_01054_overflow.ints values (12, 12, 12, 12, 12, 12, 12, 12, 12);
insert into test_01054_overflow.ints values (13, 13, 13, 13, 13, 13, 13, 13, 13);
insert into test_01054_overflow.ints values (14, 14, 14, 14, 14, 14, 14, 14, 14);
insert into test_01054_overflow.ints values (15, 15, 15, 15, 15, 15, 15, 15, 15);
insert into test_01054_overflow.ints values (16, 16, 16, 16, 16, 16, 16, 16, 16);
insert into test_01054_overflow.ints values (17, 17, 17, 17, 17, 17, 17, 17, 17);
insert into test_01054_overflow.ints values (18, 18, 18, 18, 18, 18, 18, 18, 18);
insert into test_01054_overflow.ints values (19, 19, 19, 19, 19, 19, 19, 19, 19);
insert into test_01054_overflow.ints values (20, 20, 20, 20, 20, 20, 20, 20, 20);
select
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(1)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(2)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(3)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(4)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(5)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(6)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(7)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(8)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(9)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(10)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(11)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(12)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(13)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(14)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(15)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(16)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(17)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(18)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(19)),
dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(20));
SELECT arrayMap(x -> dictGet('one_cell_cache_ints_overflow', 'i8', toUInt64(x)), array)
FROM
(
SELECT [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] AS array
);
DROP TABLE if exists test_01054.ints;

View File

@ -0,0 +1,11 @@
CREATE USER test_user_01073
A
B
GRANT DELETE, INSERT ON *.* TO test_user_01073
GRANT SELECT ON db1.* TO test_user_01073
GRANT SELECT ON db2.table TO test_user_01073
GRANT SELECT(col1) ON db3.table TO test_user_01073
GRANT SELECT(col1, col2) ON db4.table TO test_user_01073
C
GRANT DELETE ON *.* TO test_user_01073
GRANT SELECT(col1) ON db4.table TO test_user_01073

View File

@ -0,0 +1,28 @@
DROP USER IF EXISTS test_user_01073;
CREATE USER test_user_01073;
SHOW CREATE USER test_user_01073;
SELECT 'A';
SHOW GRANTS FOR test_user_01073;
GRANT SELECT ON db1.* TO test_user_01073;
GRANT SELECT ON db2.table TO test_user_01073;
GRANT SELECT(col1) ON db3.table TO test_user_01073;
GRANT SELECT(col1, col2) ON db4.table TO test_user_01073;
GRANT INSERT ON *.* TO test_user_01073;
GRANT DELETE ON *.* TO test_user_01073;
SELECT 'B';
SHOW GRANTS FOR test_user_01073;
REVOKE SELECT ON db1.* FROM test_user_01073;
REVOKE SELECT ON db2.table FROM test_user_01073;
REVOKE SELECT ON db3.table FROM test_user_01073;
REVOKE SELECT(col2) ON db4.table FROM test_user_01073;
REVOKE INSERT ON *.* FROM test_user_01073;
SELECT 'C';
SHOW GRANTS FOR test_user_01073;
DROP USER test_user_01073;

View File

@ -0,0 +1,5 @@
A
GRANT SELECT ON *.* TO test_user_01074
B
GRANT SELECT ON *.* TO test_user_01074
REVOKE SELECT ON db.* FROM test_user_01074

View File

@ -0,0 +1,15 @@
DROP USER IF EXISTS test_user_01074;
CREATE USER test_user_01074;
SELECT 'A';
SET partial_revokes=0;
GRANT SELECT ON *.* TO test_user_01074;
REVOKE SELECT ON db.* FROM test_user_01074;
SHOW GRANTS FOR test_user_01074;
SELECT 'B';
SET partial_revokes=1;
REVOKE SELECT ON db.* FROM test_user_01074;
SHOW GRANTS FOR test_user_01074;
DROP USER test_user_01074;

View File

@ -0,0 +1,14 @@
CREATE USER test_user_01075
CREATE USER test_user_01075
CREATE USER test_user_01075 HOST NONE
CREATE USER test_user_01075 HOST LOCAL
CREATE USER test_user_01075 HOST IP \'192.168.23.15\'
CREATE USER test_user_01075 HOST IP \'2001:db8:11a3:9d7:1f34:8a2e:7a0:765d\'
CREATE USER test_user_01075 HOST LOCAL, IP \'2001:db8:11a3:9d7:1f34:8a2e:7a0:765d\'
CREATE USER test_user_01075 HOST LOCAL
CREATE USER test_user_01075 HOST NONE
CREATE USER test_user_01075 HOST LIKE \'@.somesite.com\'
CREATE USER test_user_01075 HOST NAME REGEXP \'.*.anothersite.com\'
CREATE USER `test_user_01075_x@localhost` HOST LOCAL
CREATE USER test_user_01075_x
CREATE USER `test_user_01075_x@192.168.23.15` HOST LIKE \'192.168.23.15\'

View File

@ -0,0 +1,47 @@
DROP USER IF EXISTS test_user_01075, test_user_01075_x, test_user_01075_x@localhost, test_user_01075_x@'192.168.23.15';
CREATE USER test_user_01075;
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 HOST ANY;
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 HOST NONE;
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 HOST LOCAL;
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 HOST IP '192.168.23.15';
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 HOST IP '2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d';
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 ADD HOST IP '127.0.0.1';
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 REMOVE HOST IP '2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d';
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 REMOVE HOST NAME 'localhost';
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 HOST LIKE '@.somesite.com';
SHOW CREATE USER test_user_01075;
ALTER USER test_user_01075 HOST NAME REGEXP '.*\.anothersite\.com';
SHOW CREATE USER test_user_01075;
DROP USER test_user_01075;
CREATE USER test_user_01075_x@localhost;
SHOW CREATE USER test_user_01075_x@localhost;
ALTER USER test_user_01075_x@localhost RENAME TO test_user_01075_x@'@';
SHOW CREATE USER test_user_01075_x;
ALTER USER test_user_01075_x RENAME TO test_user_01075_x@'192.168.23.15';
SHOW CREATE USER 'test_user_01075_x@192.168.23.15';
DROP USER 'test_user_01075_x@192.168.23.15';

View File

@ -0,0 +1,10 @@
SET log_queries = 1;
SELECT 1 LIMIT 0;
SYSTEM FLUSH LOGS;
SELECT arrayJoin AS kv_key
FROM system.query_log
ARRAY JOIN ProfileEvents.Names AS arrayJoin
PREWHERE has(arrayMap(key -> key, ProfileEvents.Names), 'Query')
WHERE arrayJoin = 'Query'
LIMIT 0;

View File

@ -0,0 +1,17 @@
1725
1725
1725
1725
1725
Starting alters
Finishing alters
1
0
1
0
1
0
1
0
1
0

Some files were not shown because too many files have changed in this diff Show More