From 12d1d87d648b1f875f9cee143140a42ee4728c97 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Tue, 8 Oct 2019 16:44:44 +0300 Subject: [PATCH 1/5] Move authentication code to a separate class. --- dbms/CMakeLists.txt | 1 + dbms/programs/server/MySQLHandler.cpp | 6 +- dbms/src/Access/Authentication.cpp | 210 +++++++++++++++++++++++++ dbms/src/Access/Authentication.h | 67 ++++++++ dbms/src/Access/CMakeLists.txt | 0 dbms/src/CMakeLists.txt | 1 + dbms/src/Core/MySQLProtocol.h | 4 +- dbms/src/Interpreters/Users.cpp | 26 +-- dbms/src/Interpreters/Users.h | 7 +- dbms/src/Interpreters/UsersManager.cpp | 70 +-------- 10 files changed, 296 insertions(+), 96 deletions(-) create mode 100644 dbms/src/Access/Authentication.cpp create mode 100644 dbms/src/Access/Authentication.h create mode 100644 dbms/src/Access/CMakeLists.txt diff --git a/dbms/CMakeLists.txt b/dbms/CMakeLists.txt index 229639a8a7f..ec9ffc6e3dd 100644 --- a/dbms/CMakeLists.txt +++ b/dbms/CMakeLists.txt @@ -164,6 +164,7 @@ macro(add_object_library name common_path) endif () endmacro() +add_object_library(clickhouse_access src/Access) add_object_library(clickhouse_core src/Core) add_object_library(clickhouse_compression src/Compression) add_object_library(clickhouse_datastreams src/DataStreams) diff --git a/dbms/programs/server/MySQLHandler.cpp b/dbms/programs/server/MySQLHandler.cpp index 1b495552fbc..f7429ebf2a7 100644 --- a/dbms/programs/server/MySQLHandler.cpp +++ b/dbms/programs/server/MySQLHandler.cpp @@ -46,7 +46,7 @@ MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & so , connection_id(connection_id_) , public_key(public_key_) , private_key(private_key_) - , auth_plugin(new Authentication::Native41()) + , auth_plugin(new MySQLProtocol::Authentication::Native41()) { server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF; if (ssl_enabled) @@ -231,8 +231,8 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl { // For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used. auto user = connection_context.getUser(user_name); - if (user->password_double_sha1_hex.empty()) - auth_plugin = std::make_unique(public_key, private_key, log); + if (user->authentication.getType() != DB::Authentication::DOUBLE_SHA1_PASSWORD) + auth_plugin = std::make_unique(public_key, private_key, log); try { std::optional auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional(initial_auth_response) : std::nullopt; diff --git a/dbms/src/Access/Authentication.cpp b/dbms/src/Access/Authentication.cpp new file mode 100644 index 00000000000..279cd4978f0 --- /dev/null +++ b/dbms/src/Access/Authentication.cpp @@ -0,0 +1,210 @@ +#include +#include +#include +#include +#include +#include +#include "config_core.h" +#if USE_SSL +# include +#endif + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int SUPPORT_IS_DISABLED; + extern const int REQUIRED_PASSWORD; + extern const int WRONG_PASSWORD; + extern const int BAD_ARGUMENTS; + extern const int LOGICAL_ERROR; +} + + +namespace +{ + using Digest = Authentication::Digest; + + Digest encodePlainText(const StringRef & text) + { + Digest digest; + digest.resize(text.size); + memcpy(digest.data(), text.data, text.size); + return digest; + } + + Digest encodeSHA256(const StringRef & text) + { +#if USE_SSL + Digest hash; + hash.resize(32); + SHA256_CTX ctx; + SHA256_Init(&ctx); + SHA256_Update(&ctx, reinterpret_cast(text.data), text.size); + SHA256_Final(hash.data(), &ctx); + return hash; +#else + UNUSED(text); + throw DB::Exception("SHA256 passwords support is disabled, because ClickHouse was built without SSL library", DB::ErrorCodes::SUPPORT_IS_DISABLED); +#endif + } + + Digest encodeSHA1(const StringRef & text) + { + Poco::SHA1Engine engine; + engine.update(text.data, text.size); + return engine.digest(); + } + + Digest encodeSHA1(const Digest & text) + { + return encodeSHA1(StringRef{reinterpret_cast(text.data()), text.size()}); + } + + Digest encodeDoubleSHA1(const StringRef & text) + { + return encodeSHA1(encodeSHA1(text)); + } +} + + +Authentication::Authentication(Authentication::Type type_) + : type(type_) +{ +} + + +void Authentication::setPassword(const String & password_) +{ + switch (type) + { + case NO_PASSWORD: + throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR); + + case PLAINTEXT_PASSWORD: + setPasswordHashBinary(encodePlainText(password_)); + return; + + case SHA256_PASSWORD: + setPasswordHashBinary(encodeSHA256(password_)); + return; + + case DOUBLE_SHA1_PASSWORD: + setPasswordHashBinary(encodeDoubleSHA1(password_)); + return; + } + throw Exception("Unknown authentication type: " + std::to_string(static_cast(type)), ErrorCodes::LOGICAL_ERROR); +} + + +String Authentication::getPassword() const +{ + if (type != PLAINTEXT_PASSWORD) + throw Exception("Cannot decode the password", ErrorCodes::LOGICAL_ERROR); + return String(reinterpret_cast(password_hash.data()), password_hash.size()); +} + + +void Authentication::setPasswordHashHex(const String & hash) +{ + Digest digest; + digest.resize(hash.size() / 2); + boost::algorithm::unhex(hash.begin(), hash.end(), digest.data()); + setPasswordHashBinary(digest); +} + + +String Authentication::getPasswordHashHex() const +{ + String hex; + hex.resize(password_hash.size() * 2); + boost::algorithm::hex(password_hash.begin(), password_hash.end(), hex.data()); + return hex; +} + + +void Authentication::setPasswordHashBinary(const Digest & hash) +{ + switch (type) + { + case NO_PASSWORD: + throw Exception("Cannot specify password for the 'NO_PASSWORD' authentication type", ErrorCodes::LOGICAL_ERROR); + + case PLAINTEXT_PASSWORD: + { + password_hash = hash; + return; + } + + case SHA256_PASSWORD: + { + if (hash.size() != 32) + throw Exception( + "Password hash for the 'SHA256_PASSWORD' authentication type has length " + std::to_string(hash.size()) + + " but must be exactly 32 bytes.", + ErrorCodes::BAD_ARGUMENTS); + password_hash = hash; + return; + } + + case DOUBLE_SHA1_PASSWORD: + { + if (hash.size() != 20) + throw Exception( + "Password hash for the 'DOUBLE_SHA1_PASSWORD' authentication type has length " + std::to_string(hash.size()) + + " but must be exactly 20 bytes.", + ErrorCodes::BAD_ARGUMENTS); + password_hash = hash; + return; + } + } + throw Exception("Unknown authentication type: " + std::to_string(static_cast(type)), ErrorCodes::LOGICAL_ERROR); +} + + +bool Authentication::isCorrectPassword(const String & password_) const +{ + switch (type) + { + case NO_PASSWORD: + return true; + + case PLAINTEXT_PASSWORD: + return password_ == StringRef{reinterpret_cast(password_hash.data()), password_hash.size()}; + + case SHA256_PASSWORD: + return encodeSHA256(password_) == password_hash; + + case DOUBLE_SHA1_PASSWORD: + { + auto first_sha1 = encodeSHA1(password_); + + /// If it was MySQL compatibility server, then first_sha1 already contains double SHA1. + if (first_sha1 == password_hash) + return true; + + return encodeSHA1(first_sha1) == password_hash; + } + } + throw Exception("Unknown authentication type: " + std::to_string(static_cast(type)), ErrorCodes::LOGICAL_ERROR); +} + + +void Authentication::checkPassword(const String & password_, const String & user_name) const +{ + if (isCorrectPassword(password_)) + return; + auto info_about_user_name = [&user_name]() { return user_name.empty() ? String() : " for user " + user_name; }; + if (password_.empty() && (type != NO_PASSWORD)) + throw Exception("Password required" + info_about_user_name(), ErrorCodes::REQUIRED_PASSWORD); + throw Exception("Wrong password" + info_about_user_name(), ErrorCodes::WRONG_PASSWORD); +} + + +bool operator ==(const Authentication & lhs, const Authentication & rhs) +{ + return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash); +} +} + diff --git a/dbms/src/Access/Authentication.h b/dbms/src/Access/Authentication.h new file mode 100644 index 00000000000..1f708af985b --- /dev/null +++ b/dbms/src/Access/Authentication.h @@ -0,0 +1,67 @@ +#pragma once + +#include + + +namespace DB +{ +/// Authentication type and encrypted password for checking when an user logins. +class Authentication +{ +public: + enum Type + { + /// User doesn't have to enter password. + NO_PASSWORD, + + /// Password is stored as is. + PLAINTEXT_PASSWORD, + + /// Password is encrypted in SHA256 hash. + SHA256_PASSWORD, + + /// SHA1(SHA1(password)). + /// This kind of hash is used by the `mysql_native_password` authentication plugin. + DOUBLE_SHA1_PASSWORD, + }; + + using Digest = std::vector; + + Authentication(Authentication::Type type = NO_PASSWORD); + Authentication(const Authentication & src) = default; + Authentication & operator =(const Authentication & src) = default; + Authentication(Authentication && src) = default; + Authentication & operator =(Authentication && src) = default; + + void setType(Type type_) { type = type_; } + Type getType() const { return type; } + + /// Sets the password. This function uses the authentication type set with setType() to encode the password. + void setPassword(const String & password); + + /// Returns the password. Allowed to use only for Type::PLAINTEXT_PASSWORD. + String getPassword() const; + + /// Sets the password as a string of hexadecimal digits. + void setPasswordHashHex(const String & hash); + String getPasswordHashHex() const; + + /// Sets the password in binary form. + void setPasswordHashBinary(const Digest & hash); + const Digest & getPasswordHashBinary() const { return password_hash; } + + /// Checks if the provided password is correct. Returns false if not. + bool isCorrectPassword(const String & password) const; + + /// Checks if the provided password is correct. Throws an exception if not. + /// `user_name` is only used for generating an error message if the password is incorrect. + void checkPassword(const String & password, const String & user_name = String()) const; + + friend bool operator ==(const Authentication & lhs, const Authentication & rhs); + friend bool operator !=(const Authentication & lhs, const Authentication & rhs) { return !(lhs == rhs); } + +private: + Type type = Type::NO_PASSWORD; + Digest password_hash; +}; +} diff --git a/dbms/src/Access/CMakeLists.txt b/dbms/src/Access/CMakeLists.txt new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dbms/src/CMakeLists.txt b/dbms/src/CMakeLists.txt index 84755f7f280..591fcd784b3 100644 --- a/dbms/src/CMakeLists.txt +++ b/dbms/src/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory (Access) add_subdirectory (Columns) add_subdirectory (Common) add_subdirectory (Core) diff --git a/dbms/src/Core/MySQLProtocol.h b/dbms/src/Core/MySQLProtocol.h index 2ac255cca34..2829e489f25 100644 --- a/dbms/src/Core/MySQLProtocol.h +++ b/dbms/src/Core/MySQLProtocol.h @@ -919,10 +919,10 @@ public: auto user = context.getUser(user_name); - if (user->password_double_sha1_hex.empty()) + if (user->authentication.getType() != DB::Authentication::DOUBLE_SHA1_PASSWORD) throw Exception("Cannot use " + getName() + " auth plugin for user " + user_name + " since its password isn't specified using double SHA1.", ErrorCodes::UNKNOWN_EXCEPTION); - Poco::SHA1Engine::Digest double_sha1_value = Poco::DigestEngine::digestFromHex(user->password_double_sha1_hex); + Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordHashBinary(); assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE); Poco::SHA1Engine engine; diff --git a/dbms/src/Interpreters/Users.cpp b/dbms/src/Interpreters/Users.cpp index 86e37f0b729..a60defb5dda 100644 --- a/dbms/src/Interpreters/Users.cpp +++ b/dbms/src/Interpreters/Users.cpp @@ -3,20 +3,13 @@ #include #include #include -#include #include -#include #include #include -#include -#include -#include #include -#include #include #include #include -#include namespace DB @@ -27,8 +20,6 @@ namespace ErrorCodes extern const int DNS_ERROR; extern const int UNKNOWN_ADDRESS_PATTERN_TYPE; extern const int UNKNOWN_USER; - extern const int REQUIRED_PASSWORD; - extern const int WRONG_PASSWORD; extern const int IP_ADDRESS_NOT_ALLOWED; extern const int BAD_ARGUMENTS; } @@ -288,22 +279,21 @@ User::User(const String & name_, const String & config_elem, const Poco::Util::A throw Exception("Either 'password' or 'password_sha256_hex' or 'password_double_sha1_hex' must be specified for user " + name + ".", ErrorCodes::BAD_ARGUMENTS); if (has_password) - password = config.getString(config_elem + ".password"); + { + authentication.setType(Authentication::PLAINTEXT_PASSWORD); + authentication.setPassword(config.getString(config_elem + ".password")); + } if (has_password_sha256_hex) { - password_sha256_hex = Poco::toLower(config.getString(config_elem + ".password_sha256_hex")); - - if (password_sha256_hex.size() != 64) - throw Exception("password_sha256_hex for user " + name + " has length " + toString(password_sha256_hex.size()) + " but must be exactly 64 symbols.", ErrorCodes::BAD_ARGUMENTS); + authentication.setType(Authentication::SHA256_PASSWORD); + authentication.setPasswordHashHex(config.getString(config_elem + ".password_sha256_hex")); } if (has_password_double_sha1_hex) { - password_double_sha1_hex = Poco::toLower(config.getString(config_elem + ".password_double_sha1_hex")); - - if (password_double_sha1_hex.size() != 40) - throw Exception("password_double_sha1_hex for user " + name + " has length " + toString(password_double_sha1_hex.size()) + " but must be exactly 40 symbols.", ErrorCodes::BAD_ARGUMENTS); + authentication.setType(Authentication::DOUBLE_SHA1_PASSWORD); + authentication.setPasswordHashHex(config.getString(config_elem + ".password_double_sha1_hex")); } profile = config.getString(config_elem + ".profile"); diff --git a/dbms/src/Interpreters/Users.h b/dbms/src/Interpreters/Users.h index 090bc693e9a..0e1833427a8 100644 --- a/dbms/src/Interpreters/Users.h +++ b/dbms/src/Interpreters/Users.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -53,10 +54,8 @@ struct User { String name; - /// Required password. Could be stored in plaintext or in SHA256. - String password; - String password_sha256_hex; - String password_double_sha1_hex; + /// Required password. + Authentication authentication; String profile; String quota; diff --git a/dbms/src/Interpreters/UsersManager.cpp b/dbms/src/Interpreters/UsersManager.cpp index ee6293c3ee2..4d8cfbadc5b 100644 --- a/dbms/src/Interpreters/UsersManager.cpp +++ b/dbms/src/Interpreters/UsersManager.cpp @@ -1,18 +1,8 @@ #include -#include "config_core.h" #include -#include -#include -#include -#include #include -#include -#include #include -#if USE_SSL -# include -#endif namespace DB @@ -20,14 +10,8 @@ namespace DB namespace ErrorCodes { - extern const int DNS_ERROR; - extern const int UNKNOWN_ADDRESS_PATTERN_TYPE; extern const int UNKNOWN_USER; - extern const int REQUIRED_PASSWORD; - extern const int WRONG_PASSWORD; extern const int IP_ADDRESS_NOT_ALLOWED; - extern const int BAD_ARGUMENTS; - extern const int SUPPORT_IS_DISABLED; } using UserPtr = UsersManager::UserPtr; @@ -61,59 +45,7 @@ UserPtr UsersManager::authorizeAndGetUser( if (!it->second->addresses.contains(address)) throw Exception("User " + user_name + " is not allowed to connect from address " + address.toString(), ErrorCodes::IP_ADDRESS_NOT_ALLOWED); - auto on_wrong_password = [&]() - { - if (password.empty()) - throw Exception("Password required for user " + user_name, ErrorCodes::REQUIRED_PASSWORD); - else - throw Exception("Wrong password for user " + user_name, ErrorCodes::WRONG_PASSWORD); - }; - - if (!it->second->password_sha256_hex.empty()) - { -#if USE_SSL - unsigned char hash[32]; - - SHA256_CTX ctx; - SHA256_Init(&ctx); - SHA256_Update(&ctx, reinterpret_cast(password.data()), password.size()); - SHA256_Final(hash, &ctx); - - String hash_hex; - { - WriteBufferFromString buf(hash_hex); - HexWriteBuffer hex_buf(buf); - hex_buf.write(reinterpret_cast(hash), sizeof(hash)); - } - - Poco::toLowerInPlace(hash_hex); - - if (hash_hex != it->second->password_sha256_hex) - on_wrong_password(); -#else - throw DB::Exception("SHA256 passwords support is disabled, because ClickHouse was built without SSL library", DB::ErrorCodes::SUPPORT_IS_DISABLED); -#endif - } - else if (!it->second->password_double_sha1_hex.empty()) - { - Poco::SHA1Engine engine; - engine.update(password); - const auto & first_sha1 = engine.digest(); - - /// If it was MySQL compatibility server, then first_sha1 already contains double SHA1. - if (Poco::SHA1Engine::digestToHex(first_sha1) == it->second->password_double_sha1_hex) - return it->second; - - engine.update(first_sha1.data(), first_sha1.size()); - - if (Poco::SHA1Engine::digestToHex(engine.digest()) != it->second->password_double_sha1_hex) - on_wrong_password(); - } - else if (password != it->second->password) - { - on_wrong_password(); - } - + it->second->authentication.checkPassword(password, user_name); return it->second; } From 9f6d9d61307e928b47343e0a9b2901daf05f6996 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Tue, 8 Oct 2019 17:40:24 +0300 Subject: [PATCH 2/5] Move the code which checks client host to a separate class. --- dbms/src/Access/AllowedClientHosts.cpp | 397 +++++++++++++++++++++++++ dbms/src/Access/AllowedClientHosts.h | 103 +++++++ dbms/src/Interpreters/Users.cpp | 266 ++--------------- dbms/src/Interpreters/Users.h | 32 +- dbms/src/Interpreters/UsersManager.cpp | 6 +- 5 files changed, 522 insertions(+), 282 deletions(-) create mode 100644 dbms/src/Access/AllowedClientHosts.cpp create mode 100644 dbms/src/Access/AllowedClientHosts.h diff --git a/dbms/src/Access/AllowedClientHosts.cpp b/dbms/src/Access/AllowedClientHosts.cpp new file mode 100644 index 00000000000..4016d0ce00f --- /dev/null +++ b/dbms/src/Access/AllowedClientHosts.cpp @@ -0,0 +1,397 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int DNS_ERROR; + extern const int IP_ADDRESS_NOT_ALLOWED; +} + +namespace +{ + using IPAddress = Poco::Net::IPAddress; + + const AllowedClientHosts::IPSubnet ALL_ADDRESSES = AllowedClientHosts::IPSubnet{IPAddress{IPAddress::IPv6}, IPAddress{IPAddress::IPv6}}; + + IPAddress toIPv6(const IPAddress & addr) + { + if (addr.family() == IPAddress::IPv6) + return addr; + + return IPAddress("::FFFF:" + addr.toString()); + } + + + IPAddress maskToIPv6(const IPAddress & mask) + { + if (mask.family() == IPAddress::IPv6) + return mask; + + return IPAddress(96, IPAddress::IPv6) | toIPv6(mask); + } + + + bool isAddressOfHostImpl(const IPAddress & address, const String & host) + { + IPAddress addr_v6 = toIPv6(address); + + /// Resolve by hand, because Poco don't use AI_ALL flag but we need it. + addrinfo * ai = nullptr; + SCOPE_EXIT( + { + if (ai) + freeaddrinfo(ai); + }); + + addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_flags |= AI_V4MAPPED | AI_ALL; + + int ret = getaddrinfo(host.c_str(), nullptr, &hints, &ai); + if (0 != ret) + throw Exception("Cannot getaddrinfo: " + std::string(gai_strerror(ret)), ErrorCodes::DNS_ERROR); + + for (; ai != nullptr; ai = ai->ai_next) + { + if (ai->ai_addrlen && ai->ai_addr) + { + if (ai->ai_family == AF_INET6) + { + if (addr_v6 == IPAddress( + &reinterpret_cast(ai->ai_addr)->sin6_addr, sizeof(in6_addr), + reinterpret_cast(ai->ai_addr)->sin6_scope_id)) + { + return true; + } + } + else if (ai->ai_family == AF_INET) + { + if (addr_v6 == toIPv6(IPAddress(&reinterpret_cast(ai->ai_addr)->sin_addr, sizeof(in_addr)))) + { + return true; + } + } + } + } + + return false; + } + + + /// Cached version of isAddressOfHostImpl(). We need to cache DNS requests. + bool isAddressOfHost(const IPAddress & address, const String & host) + { + static SimpleCache cache; + return cache(address, host); + } + + + String getHostByAddressImpl(const IPAddress & address) + { + Poco::Net::SocketAddress sock_addr(address, 0); + + /// Resolve by hand, because Poco library doesn't have such functionality. + char host[1024]; + int gai_errno = getnameinfo(sock_addr.addr(), sock_addr.length(), host, sizeof(host), nullptr, 0, NI_NAMEREQD); + if (0 != gai_errno) + throw Exception("Cannot getnameinfo: " + std::string(gai_strerror(gai_errno)), ErrorCodes::DNS_ERROR); + + /// Check that PTR record is resolved back to client address + if (!isAddressOfHost(address, host)) + throw Exception("Host " + String(host) + " isn't resolved back to " + address.toString(), ErrorCodes::DNS_ERROR); + return host; + } + + + /// Cached version of getHostByAddressImpl(). We need to cache DNS requests. + String getHostByAddress(const IPAddress & address) + { + static SimpleCache cache; + return cache(address); + } +} + + +String AllowedClientHosts::IPSubnet::toString() const +{ + unsigned int prefix_length = mask.prefixLength(); + if (IPAddress{prefix_length, mask.family()} == mask) + return prefix.toString() + "/" + std::to_string(prefix_length); + + return prefix.toString() + "/" + mask.toString(); +} + + +AllowedClientHosts::AllowedClientHosts() +{ +} + + +AllowedClientHosts::AllowedClientHosts(AllAddressesTag) +{ + addAllAddresses(); +} + + +AllowedClientHosts::~AllowedClientHosts() = default; + + +AllowedClientHosts::AllowedClientHosts(const AllowedClientHosts & src) +{ + *this = src; +} + + +AllowedClientHosts & AllowedClientHosts::operator =(const AllowedClientHosts & src) +{ + addresses = src.addresses; + subnets = src.subnets; + host_names = src.host_names; + host_regexps = src.host_regexps; + compiled_host_regexps.clear(); + return *this; +} + + +AllowedClientHosts::AllowedClientHosts(AllowedClientHosts && src) +{ + *this = src; +} + + +AllowedClientHosts & AllowedClientHosts::operator =(AllowedClientHosts && src) +{ + addresses = std::move(src.addresses); + subnets = std::move(src.subnets); + host_names = std::move(src.host_names); + host_regexps = std::move(src.host_regexps); + compiled_host_regexps = std::move(src.compiled_host_regexps); + return *this; +} + + +void AllowedClientHosts::clear() +{ + addresses.clear(); + subnets.clear(); + host_names.clear(); + host_regexps.clear(); + compiled_host_regexps.clear(); +} + + +bool AllowedClientHosts::empty() const +{ + return addresses.empty() && subnets.empty() && host_names.empty() && host_regexps.empty(); +} + + +void AllowedClientHosts::addAddress(const IPAddress & address) +{ + IPAddress addr_v6 = toIPv6(address); + if (boost::range::find(addresses, addr_v6) == addresses.end()) + addresses.push_back(addr_v6); +} + + +void AllowedClientHosts::addAddress(const String & address) +{ + addAddress(IPAddress{address}); +} + + +void AllowedClientHosts::addSubnet(const IPSubnet & subnet) +{ + IPSubnet subnet_v6; + subnet_v6.prefix = toIPv6(subnet.prefix); + subnet_v6.mask = maskToIPv6(subnet.mask); + + if (subnet_v6.mask == IPAddress(128, IPAddress::IPv6)) + { + addAddress(subnet_v6.prefix); + return; + } + + subnet_v6.prefix = subnet_v6.prefix & subnet_v6.mask; + + if (boost::range::find(subnets, subnet_v6) == subnets.end()) + subnets.push_back(subnet_v6); +} + + +void AllowedClientHosts::addSubnet(const IPAddress & prefix, const IPAddress & mask) +{ + addSubnet(IPSubnet{prefix, mask}); +} + + +void AllowedClientHosts::addSubnet(const IPAddress & prefix, size_t num_prefix_bits) +{ + addSubnet(prefix, IPAddress(num_prefix_bits, prefix.family())); +} + + +void AllowedClientHosts::addSubnet(const String & subnet) +{ + size_t slash = subnet.find('/'); + if (slash == String::npos) + { + addAddress(subnet); + return; + } + + IPAddress prefix{String{subnet, 0, slash}}; + String mask(subnet, slash + 1, subnet.length() - slash - 1); + if (std::all_of(mask.begin(), mask.end(), isNumericASCII)) + addSubnet(prefix, parseFromString(mask)); + else + addSubnet(prefix, IPAddress{mask}); +} + + +void AllowedClientHosts::addHostName(const String & host_name) +{ + if (boost::range::find(host_names, host_name) == host_names.end()) + host_names.push_back(host_name); +} + + +void AllowedClientHosts::addHostRegexp(const String & host_regexp) +{ + if (boost::range::find(host_regexps, host_regexp) == host_regexps.end()) + host_regexps.push_back(host_regexp); +} + + +void AllowedClientHosts::addAllAddresses() +{ + clear(); + addSubnet(ALL_ADDRESSES); +} + + +bool AllowedClientHosts::containsAllAddresses() const +{ + return (boost::range::find(subnets, ALL_ADDRESSES) != subnets.end()) + || (boost::range::find(host_regexps, ".*") != host_regexps.end()) + || (boost::range::find(host_regexps, "$") != host_regexps.end()); +} + + +bool AllowedClientHosts::contains(const IPAddress & address) const +{ + return containsImpl(address, String(), nullptr); +} + + +void AllowedClientHosts::checkContains(const IPAddress & address, const String & user_name) const +{ + String error; + if (!containsImpl(address, user_name, &error)) + throw Exception(error, ErrorCodes::IP_ADDRESS_NOT_ALLOWED); +} + + +bool AllowedClientHosts::containsImpl(const IPAddress & address, const String & user_name, String * error) const +{ + if (error) + error->clear(); + + /// Check `ip_addresses`. + IPAddress addr_v6 = toIPv6(address); + if (boost::range::find(addresses, addr_v6) != addresses.end()) + return true; + + /// Check `ip_subnets`. + for (const auto & subnet : subnets) + if ((addr_v6 & subnet.mask) == subnet.prefix) + return true; + + /// Check `hosts`. + for (const String & host_name : host_names) + { + try + { + if (isAddressOfHost(address, host_name)) + return true; + } + catch (Exception & e) + { + if (e.code() != ErrorCodes::DNS_ERROR) + e.rethrow(); + + /// Try to ignore DNS errors: if host cannot be resolved, skip it and try next. + LOG_WARNING( + &Logger::get("AddressPatterns"), + "Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText() + << ", code = " << e.code()); + } + } + + /// Check `host_regexps`. + if (!host_regexps.empty()) + { + compileRegexps(); + try + { + String resolved_host = getHostByAddress(address); + for (const auto & compiled_regexp : compiled_host_regexps) + { + if (compiled_regexp && compiled_regexp->match(resolved_host)) + return true; + } + } + catch (Exception & e) + { + if (e.code() != ErrorCodes::DNS_ERROR) + e.rethrow(); + + /// Try to ignore DNS errors: if host cannot be resolved, skip it and try next. + LOG_WARNING( + &Logger::get("AddressPatterns"), + "Failed to check if the allowed client hosts contain address " << address.toString() << ". " << e.displayText() + << ", code = " << e.code()); + } + } + + if (error) + { + if (user_name.empty()) + *error = "It's not allowed to connect from address " + address.toString(); + else + *error = "User " + user_name + " is not allowed to connect from address " + address.toString(); + } + return false; +} + + +void AllowedClientHosts::compileRegexps() const +{ + if (compiled_host_regexps.size() == host_regexps.size()) + return; + size_t old_size = compiled_host_regexps.size(); + compiled_host_regexps.reserve(host_regexps.size()); + for (size_t i = old_size; i != host_regexps.size(); ++i) + compiled_host_regexps.emplace_back(std::make_unique(host_regexps[i])); +} + + +bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs) +{ + return (lhs.addresses == rhs.addresses) && (lhs.subnets == rhs.subnets) && (lhs.host_names == rhs.host_names) + && (lhs.host_regexps == rhs.host_regexps); +} +} diff --git a/dbms/src/Access/AllowedClientHosts.h b/dbms/src/Access/AllowedClientHosts.h new file mode 100644 index 00000000000..495f4e34d49 --- /dev/null +++ b/dbms/src/Access/AllowedClientHosts.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include +#include + + +namespace Poco +{ +class RegularExpression; +} + + +namespace DB +{ +/// Represents lists of hosts an user is allowed to connect to server from. +class AllowedClientHosts +{ +public: + using IPAddress = Poco::Net::IPAddress; + + struct IPSubnet + { + IPAddress prefix; + IPAddress mask; + + String toString() const; + + friend bool operator ==(const IPSubnet & lhs, const IPSubnet & rhs) { return (lhs.prefix == rhs.prefix) && (lhs.mask == rhs.mask); } + friend bool operator !=(const IPSubnet & lhs, const IPSubnet & rhs) { return !(lhs == rhs); } + }; + + struct AllAddressesTag {}; + + AllowedClientHosts(); + AllowedClientHosts(AllAddressesTag); + ~AllowedClientHosts(); + + AllowedClientHosts(const AllowedClientHosts & src); + AllowedClientHosts & operator =(const AllowedClientHosts & src); + AllowedClientHosts(AllowedClientHosts && src); + AllowedClientHosts & operator =(AllowedClientHosts && src); + + /// Removes all contained hosts. This will allow all hosts. + void clear(); + bool empty() const; + + /// Allows exact IP address. + /// For example, 213.180.204.3 or 2a02:6b8::3 + void addAddress(const IPAddress & address); + void addAddress(const String & address); + + /// Allows an IP subnet. + void addSubnet(const IPSubnet & subnet); + void addSubnet(const String & subnet); + + /// Allows an IP subnet. + /// For example, 312.234.1.1/255.255.255.0 or 2a02:6b8::3/FFFF:FFFF:FFFF:FFFF:: + void addSubnet(const IPAddress & prefix, const IPAddress & mask); + + /// Allows an IP subnet. + /// For example, 10.0.0.1/8 or 2a02:6b8::3/64 + void addSubnet(const IPAddress & prefix, size_t num_prefix_bits); + + /// Allows all addresses. + void addAllAddresses(); + + /// Allows an exact host. The `contains()` function will check that the provided address equals to one of that host's addresses. + void addHostName(const String & host_name); + + /// Allows a regular expression for the host. + void addHostRegexp(const String & host_regexp); + + const std::vector & getAddresses() const { return addresses; } + const std::vector & getSubnets() const { return subnets; } + const std::vector & getHostNames() const { return host_names; } + const std::vector & getHostRegexps() const { return host_regexps; } + + /// Checks if the provided address is in the list. Returns false if not. + bool contains(const IPAddress & address) const; + + /// Checks if any address is allowed. + bool containsAllAddresses() const; + + /// Checks if the provided address is in the list. Throws an exception if not. + /// `username` is only used for generating an error message if the address isn't in the list. + void checkContains(const IPAddress & address, const String & user_name = String()) const; + + friend bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs); + friend bool operator !=(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs) { return !(lhs == rhs); } + +private: + bool containsImpl(const IPAddress & address, const String & user_name, String * error) const; + void compileRegexps() const; + + std::vector addresses; + std::vector subnets; + std::vector host_names; + std::vector host_regexps; + mutable std::vector> compiled_host_regexps; +}; +} diff --git a/dbms/src/Interpreters/Users.cpp b/dbms/src/Interpreters/Users.cpp index a60defb5dda..35de7b26b71 100644 --- a/dbms/src/Interpreters/Users.cpp +++ b/dbms/src/Interpreters/Users.cpp @@ -1,15 +1,10 @@ #include #include -#include -#include -#include #include #include #include -#include #include #include -#include namespace DB @@ -17,253 +12,12 @@ namespace DB namespace ErrorCodes { - extern const int DNS_ERROR; extern const int UNKNOWN_ADDRESS_PATTERN_TYPE; extern const int UNKNOWN_USER; - extern const int IP_ADDRESS_NOT_ALLOWED; extern const int BAD_ARGUMENTS; } -static Poco::Net::IPAddress toIPv6(const Poco::Net::IPAddress addr) -{ - if (addr.family() == Poco::Net::IPAddress::IPv6) - return addr; - - return Poco::Net::IPAddress("::FFFF:" + addr.toString()); -} - - -/// IP-address or subnet mask. Example: 213.180.204.3 or 10.0.0.1/8 or 312.234.1.1/255.255.255.0 -/// 2a02:6b8::3 or 2a02:6b8::3/64 or 2a02:6b8::3/FFFF:FFFF:FFFF:FFFF:: -class IPAddressPattern : public IAddressPattern -{ -private: - /// Address of mask. Always transformed to IPv6. - Poco::Net::IPAddress mask_address; - /// Mask of net (ip form). Always transformed to IPv6. - Poco::Net::IPAddress subnet_mask; - -public: - explicit IPAddressPattern(const String & str) - { - const char * pos = strchr(str.c_str(), '/'); - - if (nullptr == pos) - { - construct(Poco::Net::IPAddress(str)); - } - else - { - String addr(str, 0, pos - str.c_str()); - auto real_address = Poco::Net::IPAddress(addr); - - String str_mask(str, addr.length() + 1, str.length() - addr.length() - 1); - if (isDigits(str_mask)) - { - UInt8 prefix_bits = parse(pos + 1); - construct(prefix_bits, real_address.family() == Poco::Net::AddressFamily::IPv4); - } - else - { - subnet_mask = netmaskToIPv6(Poco::Net::IPAddress(str_mask)); - } - - mask_address = toIPv6(real_address); - } - } - - bool contains(const Poco::Net::IPAddress & addr) const override - { - return prefixBitsEquals(addr, mask_address, subnet_mask); - } - -private: - void construct(const Poco::Net::IPAddress & mask_address_) - { - mask_address = toIPv6(mask_address_); - subnet_mask = Poco::Net::IPAddress(128, Poco::Net::IPAddress::IPv6); - } - - void construct(UInt8 prefix_bits, bool is_ipv4) - { - prefix_bits = is_ipv4 ? prefix_bits + 96 : prefix_bits; - subnet_mask = Poco::Net::IPAddress(prefix_bits, Poco::Net::IPAddress::IPv6); - } - - static bool prefixBitsEquals(const Poco::Net::IPAddress & ip_address, const Poco::Net::IPAddress & net_address, const Poco::Net::IPAddress & mask) - { - return ((toIPv6(ip_address) & mask) == (toIPv6(net_address) & mask)); - } - - static bool isDigits(const std::string & str) - { - return std::all_of(str.begin(), str.end(), isNumericASCII); - } - - static Poco::Net::IPAddress netmaskToIPv6(Poco::Net::IPAddress mask) - { - if (mask.family() == Poco::Net::IPAddress::IPv6) - return mask; - - return Poco::Net::IPAddress(96, Poco::Net::IPAddress::IPv6) | toIPv6(mask); - } -}; - -/// Check that address equals to one of hostname addresses. -class HostExactPattern : public IAddressPattern -{ -private: - String host; - - static bool containsImpl(const String & host, const Poco::Net::IPAddress & addr) - { - Poco::Net::IPAddress addr_v6 = toIPv6(addr); - - /// Resolve by hand, because Poco don't use AI_ALL flag but we need it. - addrinfo * ai = nullptr; - - addrinfo hints; - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC; - hints.ai_flags |= AI_V4MAPPED | AI_ALL; - - int ret = getaddrinfo(host.c_str(), nullptr, &hints, &ai); - if (0 != ret) - throw Exception("Cannot getaddrinfo: " + std::string(gai_strerror(ret)), ErrorCodes::DNS_ERROR); - - SCOPE_EXIT( - { - freeaddrinfo(ai); - }); - - for (; ai != nullptr; ai = ai->ai_next) - { - if (ai->ai_addrlen && ai->ai_addr) - { - if (ai->ai_family == AF_INET6) - { - if (addr_v6 == Poco::Net::IPAddress( - &reinterpret_cast(ai->ai_addr)->sin6_addr, sizeof(in6_addr), - reinterpret_cast(ai->ai_addr)->sin6_scope_id)) - { - return true; - } - } - else if (ai->ai_family == AF_INET) - { - if (addr_v6 == toIPv6(Poco::Net::IPAddress( - &reinterpret_cast(ai->ai_addr)->sin_addr, sizeof(in_addr)))) - { - return true; - } - } - } - } - - return false; - } - -public: - explicit HostExactPattern(const String & host_) : host(host_) {} - - bool contains(const Poco::Net::IPAddress & addr) const override - { - static SimpleCache cache; - return cache(host, addr); - } -}; - - -/// Check that PTR record for address match the regexp (and in addition, check that PTR record is resolved back to client address). -class HostRegexpPattern : public IAddressPattern -{ -private: - Poco::RegularExpression host_regexp; - - static String getDomain(const Poco::Net::IPAddress & addr) - { - Poco::Net::SocketAddress sock_addr(addr, 0); - - /// Resolve by hand, because Poco library doesn't have such functionality. - char domain[1024]; - int gai_errno = getnameinfo(sock_addr.addr(), sock_addr.length(), domain, sizeof(domain), nullptr, 0, NI_NAMEREQD); - if (0 != gai_errno) - throw Exception("Cannot getnameinfo: " + std::string(gai_strerror(gai_errno)), ErrorCodes::DNS_ERROR); - - return domain; - } - -public: - explicit HostRegexpPattern(const String & host_regexp_) : host_regexp(host_regexp_) {} - - bool contains(const Poco::Net::IPAddress & addr) const override - { - static SimpleCache cache; - - String domain = cache(addr); - Poco::RegularExpression::Match match; - - if (host_regexp.match(domain, match) && HostExactPattern(domain).contains(addr)) - return true; - - return false; - } -}; - - - -bool AddressPatterns::contains(const Poco::Net::IPAddress & addr) const -{ - for (size_t i = 0, size = patterns.size(); i < size; ++i) - { - /// If host cannot be resolved, skip it and try next. - try - { - if (patterns[i]->contains(addr)) - return true; - } - catch (const DB::Exception & e) - { - LOG_WARNING(&Logger::get("AddressPatterns"), - "Failed to check if pattern contains address " << addr.toString() << ". " << e.displayText() << ", code = " << e.code()); - - if (e.code() == ErrorCodes::DNS_ERROR) - { - continue; - } - else - throw; - } - } - - return false; -} - -void AddressPatterns::addFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config) -{ - Poco::Util::AbstractConfiguration::Keys config_keys; - config.keys(config_elem, config_keys); - - for (Poco::Util::AbstractConfiguration::Keys::const_iterator it = config_keys.begin(); it != config_keys.end(); ++it) - { - Container::value_type pattern; - String value = config.getString(config_elem + "." + *it); - - if (startsWith(*it, "ip")) - pattern = std::make_unique(value); - else if (startsWith(*it, "host_regexp")) - pattern = std::make_unique(value); - else if (startsWith(*it, "host")) - pattern = std::make_unique(value); - else - throw Exception("Unknown address pattern type: " + *it, ErrorCodes::UNKNOWN_ADDRESS_PATTERN_TYPE); - - patterns.emplace_back(std::move(pattern)); - } -} - - User::User(const String & name_, const String & config_elem, const Poco::Util::AbstractConfiguration & config) : name(name_) { @@ -299,7 +53,25 @@ User::User(const String & name_, const String & config_elem, const Poco::Util::A profile = config.getString(config_elem + ".profile"); quota = config.getString(config_elem + ".quota"); - addresses.addFromConfig(config_elem + ".networks", config); + /// Fill list of allowed hosts. + const auto config_networks = config_elem + ".networks"; + if (config.has(config_networks)) + { + Poco::Util::AbstractConfiguration::Keys config_keys; + config.keys(config_networks, config_keys); + for (Poco::Util::AbstractConfiguration::Keys::const_iterator it = config_keys.begin(); it != config_keys.end(); ++it) + { + String value = config.getString(config_networks + "." + *it); + if (startsWith(*it, "ip")) + allowed_client_hosts.addSubnet(value); + else if (startsWith(*it, "host_regexp")) + allowed_client_hosts.addHostRegexp(value); + else if (startsWith(*it, "host")) + allowed_client_hosts.addHostName(value); + else + throw Exception("Unknown address pattern type: " + *it, ErrorCodes::UNKNOWN_ADDRESS_PATTERN_TYPE); + } + } /// Fill list of allowed databases. const auto config_sub_elem = config_elem + ".allow_databases"; diff --git a/dbms/src/Interpreters/Users.h b/dbms/src/Interpreters/Users.h index 0e1833427a8..a2d4ccece45 100644 --- a/dbms/src/Interpreters/Users.h +++ b/dbms/src/Interpreters/Users.h @@ -2,20 +2,15 @@ #include #include +#include #include #include #include -#include namespace Poco { - namespace Net - { - class IPAddress; - } - namespace Util { class AbstractConfiguration; @@ -25,29 +20,6 @@ namespace Poco namespace DB { - - -/// Allow to check that address matches a pattern. -class IAddressPattern -{ -public: - virtual bool contains(const Poco::Net::IPAddress & addr) const = 0; - virtual ~IAddressPattern() {} -}; - - -class AddressPatterns -{ -private: - using Container = std::vector>; - Container patterns; - -public: - bool contains(const Poco::Net::IPAddress & addr) const; - void addFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config); -}; - - /** User and ACL. */ struct User @@ -60,7 +32,7 @@ struct User String profile; String quota; - AddressPatterns addresses; + AllowedClientHosts allowed_client_hosts; /// List of allowed databases. using DatabaseSet = std::unordered_set; diff --git a/dbms/src/Interpreters/UsersManager.cpp b/dbms/src/Interpreters/UsersManager.cpp index 4d8cfbadc5b..50b5d6653a3 100644 --- a/dbms/src/Interpreters/UsersManager.cpp +++ b/dbms/src/Interpreters/UsersManager.cpp @@ -1,7 +1,6 @@ #include #include -#include #include @@ -11,7 +10,6 @@ namespace DB namespace ErrorCodes { extern const int UNKNOWN_USER; - extern const int IP_ADDRESS_NOT_ALLOWED; } using UserPtr = UsersManager::UserPtr; @@ -42,9 +40,7 @@ UserPtr UsersManager::authorizeAndGetUser( if (users.end() == it) throw Exception("Unknown user " + user_name, ErrorCodes::UNKNOWN_USER); - if (!it->second->addresses.contains(address)) - throw Exception("User " + user_name + " is not allowed to connect from address " + address.toString(), ErrorCodes::IP_ADDRESS_NOT_ALLOWED); - + it->second->allowed_client_hosts.checkContains(address, user_name); it->second->authentication.checkPassword(password, user_name); return it->second; } From 18ccb4d64d0a5b2603c249a575d20838b948e12a Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Tue, 8 Oct 2019 21:42:22 +0300 Subject: [PATCH 3/5] Move backQuote() and quoteString() to a separate file, use StringRefs. --- dbms/programs/local/LocalServer.cpp | 12 +----- dbms/src/Common/DiskSpaceMonitor.cpp | 2 +- dbms/src/Common/quoteString.cpp | 37 +++++++++++++++++++ dbms/src/Common/quoteString.h | 17 +++++++++ .../CheckConstraintsBlockOutputStream.cpp | 1 + .../ConvertingBlockInputStream.cpp | 1 + dbms/src/Databases/DatabaseOnDisk.h | 1 + dbms/src/IO/WriteHelpers.cpp | 22 ----------- dbms/src/IO/WriteHelpers.h | 35 +++++++----------- .../src/Interpreters/InterpreterDropQuery.cpp | 1 + dbms/src/Interpreters/QueryAliasesVisitor.cpp | 2 +- dbms/src/Interpreters/QueryNormalizer.cpp | 2 +- .../ReplaceQueryParameterVisitor.cpp | 1 + dbms/src/Parsers/ASTAlterQuery.cpp | 6 +-- dbms/src/Parsers/ASTCheckQuery.h | 1 + dbms/src/Parsers/ASTColumnDeclaration.cpp | 1 + dbms/src/Parsers/ASTColumnsMatcher.cpp | 11 ++---- dbms/src/Parsers/ASTConstraintDeclaration.cpp | 2 + dbms/src/Parsers/ASTCreateQuery.cpp | 1 + .../ASTDictionaryAttributeDeclaration.cpp | 2 + dbms/src/Parsers/ASTDropQuery.cpp | 1 + dbms/src/Parsers/ASTIndexDeclaration.h | 5 ++- dbms/src/Parsers/ASTInsertQuery.cpp | 1 + dbms/src/Parsers/ASTNameTypePair.h | 1 + dbms/src/Parsers/ASTOptimizeQuery.cpp | 1 + dbms/src/Parsers/ASTQueryParameter.cpp | 1 + dbms/src/Parsers/ASTQueryWithOnCluster.cpp | 1 + .../Parsers/ASTQueryWithTableAndOutput.cpp | 1 + dbms/src/Parsers/ASTRenameQuery.h | 3 +- dbms/src/Parsers/ASTShowTablesQuery.cpp | 1 + dbms/src/Parsers/ASTSystemQuery.cpp | 1 + dbms/src/Parsers/ASTUseQuery.h | 1 + dbms/src/Parsers/ASTWatchQuery.h | 1 + dbms/src/Parsers/IAST.h | 2 - dbms/src/Parsers/TablePropertiesQueriesASTs.h | 1 + .../Transforms/ConvertingTransform.cpp | 1 + dbms/src/Storages/IStorage.cpp | 1 + dbms/src/Storages/MutationCommands.cpp | 1 + dbms/src/Storages/StorageBuffer.cpp | 1 + .../Storages/getStructureOfRemoteTable.cpp | 1 + .../src/TableFunctions/TableFunctionMySQL.cpp | 2 +- 41 files changed, 113 insertions(+), 74 deletions(-) create mode 100644 dbms/src/Common/quoteString.cpp create mode 100644 dbms/src/Common/quoteString.h diff --git a/dbms/programs/local/LocalServer.cpp b/dbms/programs/local/LocalServer.cpp index f4eac1baec2..c3dfcacf3f3 100644 --- a/dbms/programs/local/LocalServer.cpp +++ b/dbms/programs/local/LocalServer.cpp @@ -19,8 +19,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -221,14 +221,6 @@ catch (const Exception & e) } -inline String getQuotedString(const String & s) -{ - WriteBufferFromOwnString buf; - writeQuotedString(s, buf); - return buf.str(); -} - - std::string LocalServer::getInitialCreateTableQuery() { if (!config().has("table-structure")) @@ -241,7 +233,7 @@ std::string LocalServer::getInitialCreateTableQuery() if (!config().has("table-file") || config().getString("table-file") == "-") /// Use Unix tools stdin naming convention table_file = "stdin"; else /// Use regular file - table_file = getQuotedString(config().getString("table-file")); + table_file = quoteString(config().getString("table-file")); return "CREATE TABLE " + table_name + diff --git a/dbms/src/Common/DiskSpaceMonitor.cpp b/dbms/src/Common/DiskSpaceMonitor.cpp index 967aa34ee40..00a146a809e 100644 --- a/dbms/src/Common/DiskSpaceMonitor.cpp +++ b/dbms/src/Common/DiskSpaceMonitor.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include diff --git a/dbms/src/Common/quoteString.cpp b/dbms/src/Common/quoteString.cpp new file mode 100644 index 00000000000..bcc6906ddfa --- /dev/null +++ b/dbms/src/Common/quoteString.cpp @@ -0,0 +1,37 @@ +#include +#include +#include + + +namespace DB +{ +String quoteString(const StringRef & x) +{ + String res(x.size, '\0'); + WriteBufferFromString wb(res); + writeQuotedString(x, wb); + return res; +} + + +String backQuote(const StringRef & x) +{ + String res(x.size, '\0'); + { + WriteBufferFromString wb(res); + writeBackQuotedString(x, wb); + } + return res; +} + + +String backQuoteIfNeed(const StringRef & x) +{ + String res(x.size, '\0'); + { + WriteBufferFromString wb(res); + writeProbablyBackQuotedString(x, wb); + } + return res; +} +} diff --git a/dbms/src/Common/quoteString.h b/dbms/src/Common/quoteString.h new file mode 100644 index 00000000000..f17f6c7015d --- /dev/null +++ b/dbms/src/Common/quoteString.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + + +namespace DB +{ +/// Quote the string. +String quoteString(const StringRef & x); + +/// Quote the identifier with backquotes. +String backQuote(const StringRef & x); + +/// Quote the identifier with backquotes, if required. +String backQuoteIfNeed(const StringRef & x); +} diff --git a/dbms/src/DataStreams/CheckConstraintsBlockOutputStream.cpp b/dbms/src/DataStreams/CheckConstraintsBlockOutputStream.cpp index 82cde69ca4e..f771a5cf20c 100644 --- a/dbms/src/DataStreams/CheckConstraintsBlockOutputStream.cpp +++ b/dbms/src/DataStreams/CheckConstraintsBlockOutputStream.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include diff --git a/dbms/src/DataStreams/ConvertingBlockInputStream.cpp b/dbms/src/DataStreams/ConvertingBlockInputStream.cpp index 320bb35f5b3..44f4989f3cc 100644 --- a/dbms/src/DataStreams/ConvertingBlockInputStream.cpp +++ b/dbms/src/DataStreams/ConvertingBlockInputStream.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include diff --git a/dbms/src/Databases/DatabaseOnDisk.h b/dbms/src/Databases/DatabaseOnDisk.h index 761d55bd90b..231db6fdccb 100644 --- a/dbms/src/Databases/DatabaseOnDisk.h +++ b/dbms/src/Databases/DatabaseOnDisk.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include diff --git a/dbms/src/IO/WriteHelpers.cpp b/dbms/src/IO/WriteHelpers.cpp index 0b5bce27b46..fe64983c18a 100644 --- a/dbms/src/IO/WriteHelpers.cpp +++ b/dbms/src/IO/WriteHelpers.cpp @@ -66,26 +66,4 @@ void writeException(const Exception & e, WriteBuffer & buf, bool with_stack_trac if (has_nested) writeException(Exception(Exception::CreateFromPoco, *e.nested()), buf, with_stack_trace); } - - -String backQuoteIfNeed(const String & x) -{ - String res(x.size(), '\0'); - { - WriteBufferFromString wb(res); - writeProbablyBackQuotedString(x, wb); - } - return res; -} - -String backQuote(const String & x) -{ - String res(x.size(), '\0'); - { - WriteBufferFromString wb(res); - writeBackQuotedString(x, wb); - } - return res; -} - } diff --git a/dbms/src/IO/WriteHelpers.h b/dbms/src/IO/WriteHelpers.h index ab3fad08860..49f34595fe1 100644 --- a/dbms/src/IO/WriteHelpers.h +++ b/dbms/src/IO/WriteHelpers.h @@ -410,36 +410,36 @@ inline void writeQuotedString(const StringRef & ref, WriteBuffer & buf) writeAnyQuotedString<'\''>(ref, buf); } -inline void writeDoubleQuotedString(const String & s, WriteBuffer & buf) +inline void writeDoubleQuotedString(const StringRef & s, WriteBuffer & buf) { writeAnyQuotedString<'"'>(s, buf); } /// Outputs a string in backquotes. -inline void writeBackQuotedString(const String & s, WriteBuffer & buf) +inline void writeBackQuotedString(const StringRef & s, WriteBuffer & buf) { writeAnyQuotedString<'`'>(s, buf); } /// Outputs a string in backquotes for MySQL. -inline void writeBackQuotedStringMySQL(const String & s, WriteBuffer & buf) +inline void writeBackQuotedStringMySQL(const StringRef & s, WriteBuffer & buf) { writeChar('`', buf); - writeAnyEscapedString<'`', true>(s.data(), s.data() + s.size(), buf); + writeAnyEscapedString<'`', true>(s.data, s.data + s.size, buf); writeChar('`', buf); } /// The same, but quotes apply only if there are characters that do not match the identifier without quotes. template -inline void writeProbablyQuotedStringImpl(const String & s, WriteBuffer & buf, F && write_quoted_string) +inline void writeProbablyQuotedStringImpl(const StringRef & s, WriteBuffer & buf, F && write_quoted_string) { - if (s.empty() || !isValidIdentifierBegin(s[0])) + if (!s.size || !isValidIdentifierBegin(s.data[0])) write_quoted_string(s, buf); else { - const char * pos = s.data() + 1; - const char * end = s.data() + s.size(); + const char * pos = s.data + 1; + const char * end = s.data + s.size; for (; pos < end; ++pos) if (!isWordCharASCII(*pos)) break; @@ -450,19 +450,19 @@ inline void writeProbablyQuotedStringImpl(const String & s, WriteBuffer & buf, F } } -inline void writeProbablyBackQuotedString(const String & s, WriteBuffer & buf) +inline void writeProbablyBackQuotedString(const StringRef & s, WriteBuffer & buf) { - writeProbablyQuotedStringImpl(s, buf, [](const String & s_, WriteBuffer & buf_) { return writeBackQuotedString(s_, buf_); }); + writeProbablyQuotedStringImpl(s, buf, [](const StringRef & s_, WriteBuffer & buf_) { return writeBackQuotedString(s_, buf_); }); } -inline void writeProbablyDoubleQuotedString(const String & s, WriteBuffer & buf) +inline void writeProbablyDoubleQuotedString(const StringRef & s, WriteBuffer & buf) { - writeProbablyQuotedStringImpl(s, buf, [](const String & s_, WriteBuffer & buf_) { return writeDoubleQuotedString(s_, buf_); }); + writeProbablyQuotedStringImpl(s, buf, [](const StringRef & s_, WriteBuffer & buf_) { return writeDoubleQuotedString(s_, buf_); }); } -inline void writeProbablyBackQuotedStringMySQL(const String & s, WriteBuffer & buf) +inline void writeProbablyBackQuotedStringMySQL(const StringRef & s, WriteBuffer & buf) { - writeProbablyQuotedStringImpl(s, buf, [](const String & s_, WriteBuffer & buf_) { return writeBackQuotedStringMySQL(s_, buf_); }); + writeProbablyQuotedStringImpl(s, buf, [](const StringRef & s_, WriteBuffer & buf_) { return writeBackQuotedStringMySQL(s_, buf_); }); } @@ -905,11 +905,4 @@ inline String toString(const T & x) writeText(x, buf); return buf.str(); } - - -/// Quote the identifier with backquotes, if required. -String backQuoteIfNeed(const String & x); -/// Quote the identifier with backquotes. -String backQuote(const String & x); - } diff --git a/dbms/src/Interpreters/InterpreterDropQuery.cpp b/dbms/src/Interpreters/InterpreterDropQuery.cpp index 7887ebc8892..565863d139a 100644 --- a/dbms/src/Interpreters/InterpreterDropQuery.cpp +++ b/dbms/src/Interpreters/InterpreterDropQuery.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include diff --git a/dbms/src/Interpreters/QueryAliasesVisitor.cpp b/dbms/src/Interpreters/QueryAliasesVisitor.cpp index 98069396d81..6de0ece8b59 100644 --- a/dbms/src/Interpreters/QueryAliasesVisitor.cpp +++ b/dbms/src/Interpreters/QueryAliasesVisitor.cpp @@ -8,7 +8,7 @@ #include #include #include -#include +#include namespace DB { diff --git a/dbms/src/Interpreters/QueryNormalizer.cpp b/dbms/src/Interpreters/QueryNormalizer.cpp index c2991885cf3..e109e4a63fd 100644 --- a/dbms/src/Interpreters/QueryNormalizer.cpp +++ b/dbms/src/Interpreters/QueryNormalizer.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include namespace DB { diff --git a/dbms/src/Interpreters/ReplaceQueryParameterVisitor.cpp b/dbms/src/Interpreters/ReplaceQueryParameterVisitor.cpp index 325499d59d2..1cbcb758bf3 100644 --- a/dbms/src/Interpreters/ReplaceQueryParameterVisitor.cpp +++ b/dbms/src/Interpreters/ReplaceQueryParameterVisitor.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include diff --git a/dbms/src/Parsers/ASTAlterQuery.cpp b/dbms/src/Parsers/ASTAlterQuery.cpp index 69ef80d4a02..93f21ae5c5e 100644 --- a/dbms/src/Parsers/ASTAlterQuery.cpp +++ b/dbms/src/Parsers/ASTAlterQuery.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include namespace DB @@ -183,9 +183,7 @@ void ASTAlterCommand::formatImpl( settings.ostr << "VOLUME "; break; } - WriteBufferFromOwnString move_destination_name_buf; - writeQuoted(move_destination_name, move_destination_name_buf); - settings.ostr << move_destination_name_buf.str(); + settings.ostr << quoteString(move_destination_name); } else if (type == ASTAlterCommand::REPLACE_PARTITION) { diff --git a/dbms/src/Parsers/ASTCheckQuery.h b/dbms/src/Parsers/ASTCheckQuery.h index 40665f6f2b6..e453a82cdb4 100644 --- a/dbms/src/Parsers/ASTCheckQuery.h +++ b/dbms/src/Parsers/ASTCheckQuery.h @@ -2,6 +2,7 @@ #include #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTColumnDeclaration.cpp b/dbms/src/Parsers/ASTColumnDeclaration.cpp index e718d5c292d..b281315f555 100644 --- a/dbms/src/Parsers/ASTColumnDeclaration.cpp +++ b/dbms/src/Parsers/ASTColumnDeclaration.cpp @@ -1,4 +1,5 @@ #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTColumnsMatcher.cpp b/dbms/src/Parsers/ASTColumnsMatcher.cpp index e9cdb822c6e..1dde9507149 100644 --- a/dbms/src/Parsers/ASTColumnsMatcher.cpp +++ b/dbms/src/Parsers/ASTColumnsMatcher.cpp @@ -1,9 +1,6 @@ #include "ASTColumnsMatcher.h" - -#include #include -#include - +#include #include @@ -22,10 +19,8 @@ void ASTColumnsMatcher::appendColumnName(WriteBuffer & ostr) const { writeString void ASTColumnsMatcher::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const { - WriteBufferFromOwnString pattern_quoted; - writeQuotedString(original_pattern, pattern_quoted); - - settings.ostr << (settings.hilite ? hilite_keyword : "") << "COLUMNS" << (settings.hilite ? hilite_none : "") << "(" << pattern_quoted.str() << ")"; + settings.ostr << (settings.hilite ? hilite_keyword : "") << "COLUMNS" << (settings.hilite ? hilite_none : "") << "(" + << quoteString(original_pattern) << ")"; } void ASTColumnsMatcher::setPattern(String pattern) diff --git a/dbms/src/Parsers/ASTConstraintDeclaration.cpp b/dbms/src/Parsers/ASTConstraintDeclaration.cpp index a1b063fc44a..f268141f619 100644 --- a/dbms/src/Parsers/ASTConstraintDeclaration.cpp +++ b/dbms/src/Parsers/ASTConstraintDeclaration.cpp @@ -1,4 +1,6 @@ #include +#include + namespace DB { diff --git a/dbms/src/Parsers/ASTCreateQuery.cpp b/dbms/src/Parsers/ASTCreateQuery.cpp index bdade881b2c..bc4a8290d8d 100644 --- a/dbms/src/Parsers/ASTCreateQuery.cpp +++ b/dbms/src/Parsers/ASTCreateQuery.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTDictionaryAttributeDeclaration.cpp b/dbms/src/Parsers/ASTDictionaryAttributeDeclaration.cpp index ebe0b900ec5..2b056cb3743 100644 --- a/dbms/src/Parsers/ASTDictionaryAttributeDeclaration.cpp +++ b/dbms/src/Parsers/ASTDictionaryAttributeDeclaration.cpp @@ -1,4 +1,6 @@ #include +#include + namespace DB { diff --git a/dbms/src/Parsers/ASTDropQuery.cpp b/dbms/src/Parsers/ASTDropQuery.cpp index b4586bf372c..56d0878ceed 100644 --- a/dbms/src/Parsers/ASTDropQuery.cpp +++ b/dbms/src/Parsers/ASTDropQuery.cpp @@ -1,4 +1,5 @@ #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTIndexDeclaration.h b/dbms/src/Parsers/ASTIndexDeclaration.h index 61e966b3d1b..c71ab21cf57 100644 --- a/dbms/src/Parsers/ASTIndexDeclaration.h +++ b/dbms/src/Parsers/ASTIndexDeclaration.h @@ -2,7 +2,8 @@ #include #include -#include +#include +#include #include #include #include @@ -52,7 +53,7 @@ public: s.ostr << (s.hilite ? hilite_keyword : "") << " TYPE " << (s.hilite ? hilite_none : ""); type->formatImpl(s, state, frame); s.ostr << (s.hilite ? hilite_keyword : "") << " GRANULARITY " << (s.hilite ? hilite_none : ""); - s.ostr << toString(granularity); + s.ostr << granularity; } }; diff --git a/dbms/src/Parsers/ASTInsertQuery.cpp b/dbms/src/Parsers/ASTInsertQuery.cpp index 1ac92f49735..89158fa0649 100644 --- a/dbms/src/Parsers/ASTInsertQuery.cpp +++ b/dbms/src/Parsers/ASTInsertQuery.cpp @@ -1,6 +1,7 @@ #include #include #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTNameTypePair.h b/dbms/src/Parsers/ASTNameTypePair.h index ac72448e2e9..48dd7ae1ac9 100644 --- a/dbms/src/Parsers/ASTNameTypePair.h +++ b/dbms/src/Parsers/ASTNameTypePair.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTOptimizeQuery.cpp b/dbms/src/Parsers/ASTOptimizeQuery.cpp index 5e95dc41795..92968f2b277 100644 --- a/dbms/src/Parsers/ASTOptimizeQuery.cpp +++ b/dbms/src/Parsers/ASTOptimizeQuery.cpp @@ -1,4 +1,5 @@ #include +#include namespace DB { diff --git a/dbms/src/Parsers/ASTQueryParameter.cpp b/dbms/src/Parsers/ASTQueryParameter.cpp index 462a08b0447..915ecd5e7e4 100644 --- a/dbms/src/Parsers/ASTQueryParameter.cpp +++ b/dbms/src/Parsers/ASTQueryParameter.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTQueryWithOnCluster.cpp b/dbms/src/Parsers/ASTQueryWithOnCluster.cpp index 9519a33c1e5..b0ccaf8b1fa 100644 --- a/dbms/src/Parsers/ASTQueryWithOnCluster.cpp +++ b/dbms/src/Parsers/ASTQueryWithOnCluster.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include diff --git a/dbms/src/Parsers/ASTQueryWithTableAndOutput.cpp b/dbms/src/Parsers/ASTQueryWithTableAndOutput.cpp index 1e16fb6f0ee..3a776590f80 100644 --- a/dbms/src/Parsers/ASTQueryWithTableAndOutput.cpp +++ b/dbms/src/Parsers/ASTQueryWithTableAndOutput.cpp @@ -1,4 +1,5 @@ #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTRenameQuery.h b/dbms/src/Parsers/ASTRenameQuery.h index 1666873ed9c..4cf007d3b36 100644 --- a/dbms/src/Parsers/ASTRenameQuery.h +++ b/dbms/src/Parsers/ASTRenameQuery.h @@ -3,11 +3,12 @@ #include #include #include +#include + namespace DB { - /** RENAME query */ class ASTRenameQuery : public ASTQueryWithOutput, public ASTQueryWithOnCluster diff --git a/dbms/src/Parsers/ASTShowTablesQuery.cpp b/dbms/src/Parsers/ASTShowTablesQuery.cpp index 4a33aeba99c..34a8c9fb76a 100644 --- a/dbms/src/Parsers/ASTShowTablesQuery.cpp +++ b/dbms/src/Parsers/ASTShowTablesQuery.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTSystemQuery.cpp b/dbms/src/Parsers/ASTSystemQuery.cpp index b0046b0179b..4e7525bb176 100644 --- a/dbms/src/Parsers/ASTSystemQuery.cpp +++ b/dbms/src/Parsers/ASTSystemQuery.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTUseQuery.h b/dbms/src/Parsers/ASTUseQuery.h index f1ef1b3b408..2127bf9f2c0 100644 --- a/dbms/src/Parsers/ASTUseQuery.h +++ b/dbms/src/Parsers/ASTUseQuery.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace DB diff --git a/dbms/src/Parsers/ASTWatchQuery.h b/dbms/src/Parsers/ASTWatchQuery.h index 06d1460f038..c4046a8771f 100644 --- a/dbms/src/Parsers/ASTWatchQuery.h +++ b/dbms/src/Parsers/ASTWatchQuery.h @@ -12,6 +12,7 @@ limitations under the License. */ #pragma once #include +#include namespace DB diff --git a/dbms/src/Parsers/IAST.h b/dbms/src/Parsers/IAST.h index c896ed2ce3f..d7c56d80a21 100644 --- a/dbms/src/Parsers/IAST.h +++ b/dbms/src/Parsers/IAST.h @@ -5,7 +5,6 @@ #include #include #include -#include /// backQuote, backQuoteIfNeed #include #include @@ -223,5 +222,4 @@ private: size_t checkDepthImpl(size_t max_depth, size_t level) const; }; - } diff --git a/dbms/src/Parsers/TablePropertiesQueriesASTs.h b/dbms/src/Parsers/TablePropertiesQueriesASTs.h index 1d787d855fc..6a8e3b2ce83 100644 --- a/dbms/src/Parsers/TablePropertiesQueriesASTs.h +++ b/dbms/src/Parsers/TablePropertiesQueriesASTs.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace DB diff --git a/dbms/src/Processors/Transforms/ConvertingTransform.cpp b/dbms/src/Processors/Transforms/ConvertingTransform.cpp index 8729b896084..e801fe7cb26 100644 --- a/dbms/src/Processors/Transforms/ConvertingTransform.cpp +++ b/dbms/src/Processors/Transforms/ConvertingTransform.cpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace DB { diff --git a/dbms/src/Storages/IStorage.cpp b/dbms/src/Storages/IStorage.cpp index f614ff8dc50..4b55cedbfcc 100644 --- a/dbms/src/Storages/IStorage.cpp +++ b/dbms/src/Storages/IStorage.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/dbms/src/Storages/MutationCommands.cpp b/dbms/src/Storages/MutationCommands.cpp index 2358bab6202..f8bc781f166 100644 --- a/dbms/src/Storages/MutationCommands.cpp +++ b/dbms/src/Storages/MutationCommands.cpp @@ -7,6 +7,7 @@ #include #include #include +#include namespace DB diff --git a/dbms/src/Storages/StorageBuffer.cpp b/dbms/src/Storages/StorageBuffer.cpp index bb4ccf8720e..44f2c466a5f 100644 --- a/dbms/src/Storages/StorageBuffer.cpp +++ b/dbms/src/Storages/StorageBuffer.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include diff --git a/dbms/src/Storages/getStructureOfRemoteTable.cpp b/dbms/src/Storages/getStructureOfRemoteTable.cpp index 137abcea649..2b6924695bf 100644 --- a/dbms/src/Storages/getStructureOfRemoteTable.cpp +++ b/dbms/src/Storages/getStructureOfRemoteTable.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include diff --git a/dbms/src/TableFunctions/TableFunctionMySQL.cpp b/dbms/src/TableFunctions/TableFunctionMySQL.cpp index 3cb9b8dea60..820a55c3a2c 100644 --- a/dbms/src/TableFunctions/TableFunctionMySQL.cpp +++ b/dbms/src/TableFunctions/TableFunctionMySQL.cpp @@ -18,9 +18,9 @@ #include #include #include +#include #include #include -#include #include #include From 060257c8c5f3602252db11511e8b4e700f2fee4f Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Fri, 11 Oct 2019 00:48:36 +0300 Subject: [PATCH 4/5] Remove Authentication::setType() function and fix comments. --- dbms/src/Access/AllowedClientHosts.h | 4 ++-- dbms/src/Access/Authentication.h | 3 +-- dbms/src/Interpreters/Users.cpp | 12 +++++------- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/dbms/src/Access/AllowedClientHosts.h b/dbms/src/Access/AllowedClientHosts.h index 495f4e34d49..fea797c2aa4 100644 --- a/dbms/src/Access/AllowedClientHosts.h +++ b/dbms/src/Access/AllowedClientHosts.h @@ -34,7 +34,7 @@ public: struct AllAddressesTag {}; AllowedClientHosts(); - AllowedClientHosts(AllAddressesTag); + explicit AllowedClientHosts(AllAddressesTag); ~AllowedClientHosts(); AllowedClientHosts(const AllowedClientHosts & src); @@ -42,7 +42,7 @@ public: AllowedClientHosts(AllowedClientHosts && src); AllowedClientHosts & operator =(AllowedClientHosts && src); - /// Removes all contained hosts. This will allow all hosts. + /// Removes all contained addresses. This will disallow all addresses. void clear(); bool empty() const; diff --git a/dbms/src/Access/Authentication.h b/dbms/src/Access/Authentication.h index 1f708af985b..d8fae6e03eb 100644 --- a/dbms/src/Access/Authentication.h +++ b/dbms/src/Access/Authentication.h @@ -33,10 +33,9 @@ public: Authentication(Authentication && src) = default; Authentication & operator =(Authentication && src) = default; - void setType(Type type_) { type = type_; } Type getType() const { return type; } - /// Sets the password. This function uses the authentication type set with setType() to encode the password. + /// Sets the password and encrypt it using the authentication type set in the constructor. void setPassword(const String & password); /// Returns the password. Allowed to use only for Type::PLAINTEXT_PASSWORD. diff --git a/dbms/src/Interpreters/Users.cpp b/dbms/src/Interpreters/Users.cpp index 35de7b26b71..8d8704165f4 100644 --- a/dbms/src/Interpreters/Users.cpp +++ b/dbms/src/Interpreters/Users.cpp @@ -34,19 +34,17 @@ User::User(const String & name_, const String & config_elem, const Poco::Util::A if (has_password) { - authentication.setType(Authentication::PLAINTEXT_PASSWORD); + authentication = Authentication{Authentication::PLAINTEXT_PASSWORD}; authentication.setPassword(config.getString(config_elem + ".password")); } - - if (has_password_sha256_hex) + else if (has_password_sha256_hex) { - authentication.setType(Authentication::SHA256_PASSWORD); + authentication = Authentication{Authentication::SHA256_PASSWORD}; authentication.setPasswordHashHex(config.getString(config_elem + ".password_sha256_hex")); } - - if (has_password_double_sha1_hex) + else if (has_password_double_sha1_hex) { - authentication.setType(Authentication::DOUBLE_SHA1_PASSWORD); + authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD}; authentication.setPasswordHashHex(config.getString(config_elem + ".password_double_sha1_hex")); } From 9e3815ccefd8b1ba9c9d7c6f87871fd45a7186f9 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Fri, 11 Oct 2019 12:53:10 +0300 Subject: [PATCH 5/5] Fix ubsan issue. --- dbms/src/Access/Authentication.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dbms/src/Access/Authentication.cpp b/dbms/src/Access/Authentication.cpp index 279cd4978f0..5b641e2906e 100644 --- a/dbms/src/Access/Authentication.cpp +++ b/dbms/src/Access/Authentication.cpp @@ -28,10 +28,7 @@ namespace Digest encodePlainText(const StringRef & text) { - Digest digest; - digest.resize(text.size); - memcpy(digest.data(), text.data, text.size); - return digest; + return Digest(text.data, text.data + text.size); } Digest encodeSHA256(const StringRef & text) @@ -102,7 +99,7 @@ String Authentication::getPassword() const { if (type != PLAINTEXT_PASSWORD) throw Exception("Cannot decode the password", ErrorCodes::LOGICAL_ERROR); - return String(reinterpret_cast(password_hash.data()), password_hash.size()); + return String(password_hash.data(), password_hash.data() + password_hash.size()); }