From 0d377de5f0dcfd0c76035d3e348dfddc7bde7d2d Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Sat, 19 Feb 2022 01:01:30 +0700 Subject: [PATCH] Support syntax CREATE USER IDENTIFIED WITH ssl_certificate CN ... --- src/Access/Authentication.cpp | 2 +- src/Access/Common/AuthenticationData.cpp | 17 +----- src/Access/Common/AuthenticationData.h | 9 ++- src/Access/UsersConfigAccessStorage.cpp | 8 ++- src/Parsers/Access/ASTCreateUserQuery.cpp | 61 +++++++++++++------ src/Parsers/Access/ParserCreateUserQuery.cpp | 18 ++++++ src/Storages/System/StorageSystemUsers.cpp | 18 ++++-- .../test_ssl_cert_authentication/test.py | 20 ++++++ .../01316_create_user_syntax_hilite.reference | 2 +- .../02117_show_create_table_system.reference | 2 +- 10 files changed, 110 insertions(+), 47 deletions(-) diff --git a/src/Access/Authentication.cpp b/src/Access/Authentication.cpp index 7bbf8ec5efa..30a2f25d497 100644 --- a/src/Access/Authentication.cpp +++ b/src/Access/Authentication.cpp @@ -186,7 +186,7 @@ bool Authentication::areCredentialsValid(const Credentials & credentials, const case AuthenticationType::SSL_CERTIFICATE: // N.B. the certificate should only be trusted when 'strict' SSL mode is enabled - if (!auth_data.containsSSLCertificateCommonName(certificate_credentials->getX509CommonName())) + if (!auth_data.getSSLCertificateCommonNames().contains(certificate_credentials->getX509CommonName())) throw Exception("X.509 certificate is not on allowed list", ErrorCodes::WRONG_PASSWORD); return true; diff --git a/src/Access/Common/AuthenticationData.cpp b/src/Access/Common/AuthenticationData.cpp index 6fee41c3dd3..711b91b4098 100644 --- a/src/Access/Common/AuthenticationData.cpp +++ b/src/Access/Common/AuthenticationData.cpp @@ -97,7 +97,8 @@ AuthenticationData::Digest AuthenticationData::Util::encodeSHA1(const std::strin bool operator ==(const AuthenticationData & lhs, const AuthenticationData & rhs) { return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash) - && (lhs.ldap_server_name == rhs.ldap_server_name) && (lhs.kerberos_realm == rhs.kerberos_realm); + && (lhs.ldap_server_name == rhs.ldap_server_name) && (lhs.kerberos_realm == rhs.kerberos_realm) + && (lhs.ssl_certificate_common_names == rhs.ssl_certificate_common_names); } @@ -209,18 +210,4 @@ void AuthenticationData::setPasswordHashBinary(const Digest & hash) throw Exception("setPasswordHashBinary(): authentication type " + toString(type) + " not supported", ErrorCodes::NOT_IMPLEMENTED); } -void AuthenticationData::clearAllowedCertificates() -{ - allowed_certificates.clear(); -} - -void AuthenticationData::addSSLCertificateCommonName(const String & x509CommonName) -{ - allowed_certificates.insert(x509CommonName); -} - -bool AuthenticationData::containsSSLCertificateCommonName(const String & x509CommonName) const -{ - return allowed_certificates.find(x509CommonName) != allowed_certificates.end(); -} } diff --git a/src/Access/Common/AuthenticationData.h b/src/Access/Common/AuthenticationData.h index 152ba55b73d..4d771ce3de3 100644 --- a/src/Access/Common/AuthenticationData.h +++ b/src/Access/Common/AuthenticationData.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include #include namespace DB @@ -84,9 +84,8 @@ public: const String & getKerberosRealm() const { return kerberos_realm; } void setKerberosRealm(const String & realm) { kerberos_realm = realm; } - void clearAllowedCertificates(); - void addSSLCertificateCommonName(const String & x509CommonName); - bool containsSSLCertificateCommonName(const String & x509CommonName) const; + const boost::container::flat_set & getSSLCertificateCommonNames() const { return ssl_certificate_common_names; } + void setSSLCertificateCommonNames(boost::container::flat_set common_names_) { ssl_certificate_common_names = std::move(common_names_); } friend bool operator ==(const AuthenticationData & lhs, const AuthenticationData & rhs); friend bool operator !=(const AuthenticationData & lhs, const AuthenticationData & rhs) { return !(lhs == rhs); } @@ -106,7 +105,7 @@ private: Digest password_hash; String ldap_server_name; String kerberos_realm; - std::set allowed_certificates; + boost::container::flat_set ssl_certificate_common_names; }; } diff --git a/src/Access/UsersConfigAccessStorage.cpp b/src/Access/UsersConfigAccessStorage.cpp index 5d7933190d5..fd10150623b 100644 --- a/src/Access/UsersConfigAccessStorage.cpp +++ b/src/Access/UsersConfigAccessStorage.cpp @@ -115,16 +115,18 @@ namespace /// Fill list of allowed certificates. Poco::Util::AbstractConfiguration::Keys keys; config.keys(certificates_config, keys); - user->auth_data.clearAllowedCertificates(); + boost::container::flat_set common_names; for (const String & key : keys) { if (key.starts_with("common_name")) { String value = config.getString(certificates_config + "." + key); - user->auth_data.addSSLCertificateCommonName(value); + common_names.insert(std::move(value)); } else - throw Exception("Unknown certificate pattern type: " + key, ErrorCodes::BAD_ARGUMENTS); } + throw Exception("Unknown certificate pattern type: " + key, ErrorCodes::BAD_ARGUMENTS); + } + user->auth_data.setSSLCertificateCommonNames(std::move(common_names)); } const auto profile_name_config = user_config + ".profile"; diff --git a/src/Parsers/Access/ASTCreateUserQuery.cpp b/src/Parsers/Access/ASTCreateUserQuery.cpp index 70f29a02d85..f8e1109886e 100644 --- a/src/Parsers/Access/ASTCreateUserQuery.cpp +++ b/src/Parsers/Access/ASTCreateUserQuery.cpp @@ -34,51 +34,62 @@ namespace } String auth_type_name = AuthenticationTypeInfo::get(auth_type).name; - String by_keyword = "BY"; - std::optional by_value; + String value_prefix; + std::optional value; + const boost::container::flat_set * values = nullptr; - if ( - show_password || + if (show_password || auth_type == AuthenticationType::LDAP || - auth_type == AuthenticationType::KERBEROS - ) + auth_type == AuthenticationType::KERBEROS || + auth_type == AuthenticationType::SSL_CERTIFICATE) { switch (auth_type) { case AuthenticationType::PLAINTEXT_PASSWORD: { - by_value = auth_data.getPassword(); + value_prefix = "BY"; + value = auth_data.getPassword(); break; } case AuthenticationType::SHA256_PASSWORD: { auth_type_name = "sha256_hash"; - by_value = auth_data.getPasswordHashHex(); + value_prefix = "BY"; + value = auth_data.getPasswordHashHex(); break; } case AuthenticationType::DOUBLE_SHA1_PASSWORD: { auth_type_name = "double_sha1_hash"; - by_value = auth_data.getPasswordHashHex(); + value_prefix = "BY"; + value = auth_data.getPasswordHashHex(); break; } case AuthenticationType::LDAP: { - by_keyword = "SERVER"; - by_value = auth_data.getLDAPServerName(); + value_prefix = "SERVER"; + value = auth_data.getLDAPServerName(); break; } case AuthenticationType::KERBEROS: { - by_keyword = "REALM"; const auto & realm = auth_data.getKerberosRealm(); if (!realm.empty()) - by_value = realm; + { + value_prefix = "REALM"; + value = realm; + } + break; + } + + case AuthenticationType::SSL_CERTIFICATE: + { + value_prefix = "CN"; + values = &auth_data.getSSLCertificateCommonNames(); break; } case AuthenticationType::NO_PASSWORD: [[fallthrough]]; - case AuthenticationType::SSL_CERTIFICATE: [[fallthrough]]; case AuthenticationType::MAX: throw Exception("AST: Unexpected authentication type " + toString(auth_type), ErrorCodes::LOGICAL_ERROR); } @@ -87,10 +98,26 @@ namespace settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " IDENTIFIED WITH " << auth_type_name << (settings.hilite ? IAST::hilite_none : ""); - if (by_value) + if (!value_prefix.empty()) { - settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " " << by_keyword << " " - << (settings.hilite ? IAST::hilite_none : "") << quoteString(*by_value); + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " " << value_prefix + << (settings.hilite ? IAST::hilite_none : ""); + } + + if (value) + { + settings.ostr << " " << quoteString(*value); + } + else if (values) + { + settings.ostr << " "; + bool need_comma = false; + for (const auto & item : *values) + { + if (std::exchange(need_comma, true)) + settings.ostr << ", "; + settings.ostr << quoteString(item); + } } } diff --git a/src/Parsers/Access/ParserCreateUserQuery.cpp b/src/Parsers/Access/ParserCreateUserQuery.cpp index c5b8c9e37b3..cde14e632dd 100644 --- a/src/Parsers/Access/ParserCreateUserQuery.cpp +++ b/src/Parsers/Access/ParserCreateUserQuery.cpp @@ -52,6 +52,7 @@ namespace bool expect_hash = false; bool expect_ldap_server_name = false; bool expect_kerberos_realm = false; + bool expect_common_names = false; if (ParserKeyword{"WITH"}.ignore(pos, expected)) { @@ -65,6 +66,8 @@ namespace expect_ldap_server_name = true; else if (check_type == AuthenticationType::KERBEROS) expect_kerberos_realm = true; + else if (check_type == AuthenticationType::SSL_CERTIFICATE) + expect_common_names = true; else if (check_type != AuthenticationType::NO_PASSWORD) expect_password = true; @@ -96,6 +99,7 @@ namespace } String value; + boost::container::flat_set common_names; if (expect_password || expect_hash) { ASTPtr ast; @@ -123,6 +127,18 @@ namespace value = ast->as().value.safeGet(); } } + else if (expect_common_names) + { + if (!ParserKeyword{"CN"}.ignore(pos, expected)) + return false; + + ASTPtr ast; + if (!ParserList{std::make_unique(), std::make_unique(TokenType::Comma), false}.parse(pos, ast, expected)) + return false; + + for (const auto & ast_child : ast->children) + common_names.insert(ast_child->as().value.safeGet()); + } auth_data = AuthenticationData{*type}; if (expect_password) @@ -133,6 +149,8 @@ namespace auth_data.setLDAPServerName(value); else if (expect_kerberos_realm) auth_data.setKerberosRealm(value); + else if (expect_common_names) + auth_data.setSSLCertificateCommonNames(std::move(common_names)); return true; }); diff --git a/src/Storages/System/StorageSystemUsers.cpp b/src/Storages/System/StorageSystemUsers.cpp index ca88fa688a0..d9b94f21c61 100644 --- a/src/Storages/System/StorageSystemUsers.cpp +++ b/src/Storages/System/StorageSystemUsers.cpp @@ -102,17 +102,27 @@ void StorageSystemUsers::fillData(MutableColumns & res_columns, ContextPtr conte column_storage.insertData(storage_name.data(), storage_name.length()); column_auth_type.push_back(static_cast(auth_data.getType())); - if ( - auth_data.getType() == AuthenticationType::LDAP || - auth_data.getType() == AuthenticationType::KERBEROS - ) + if (auth_data.getType() == AuthenticationType::LDAP || + auth_data.getType() == AuthenticationType::KERBEROS || + auth_data.getType() == AuthenticationType::SSL_CERTIFICATE) { Poco::JSON::Object auth_params_json; if (auth_data.getType() == AuthenticationType::LDAP) + { auth_params_json.set("server", auth_data.getLDAPServerName()); + } else if (auth_data.getType() == AuthenticationType::KERBEROS) + { auth_params_json.set("realm", auth_data.getKerberosRealm()); + } + else if (auth_data.getType() == AuthenticationType::SSL_CERTIFICATE) + { + Poco::JSON::Array::Ptr arr = new Poco::JSON::Array(); + for (const auto & common_name : auth_data.getSSLCertificateCommonNames()) + arr->add(common_name); + auth_params_json.set("common_names", arr); + } std::ostringstream oss; // STYLE_CHECK_ALLOW_STD_STRING_STREAM oss.exceptions(std::ios::failbit); diff --git a/tests/integration/test_ssl_cert_authentication/test.py b/tests/integration/test_ssl_cert_authentication/test.py index 4ac76f3dea6..bdc7310a6b0 100644 --- a/tests/integration/test_ssl_cert_authentication/test.py +++ b/tests/integration/test_ssl_cert_authentication/test.py @@ -95,3 +95,23 @@ def test_https_non_ssl_auth(): with pytest.raises(Exception) as err: execute_query_https("SELECT currentUser()", user="jane", enable_ssl_auth=False, password='qwe123', cert_name='wrong') assert "unknown ca" in str(err.value) + + +def test_create_user(): + instance.query("CREATE USER emma IDENTIFIED WITH ssl_certificate CN 'client3'") + assert execute_query_https("SELECT currentUser()", user="emma", cert_name='client3') == "emma\n" + assert instance.query("SHOW CREATE USER emma") == "CREATE USER emma IDENTIFIED WITH ssl_certificate CN \\'client3\\'\n" + + instance.query("ALTER USER emma IDENTIFIED WITH ssl_certificate CN 'client2'") + assert execute_query_https("SELECT currentUser()", user="emma", cert_name='client2') == "emma\n" + assert instance.query("SHOW CREATE USER emma") == "CREATE USER emma IDENTIFIED WITH ssl_certificate CN \\'client2\\'\n" + + with pytest.raises(Exception) as err: + execute_query_https("SELECT currentUser()", user="emma", cert_name='client3') + assert "HTTP Error 403" in str(err.value) + + assert instance.query("SHOW CREATE USER lucy") == "CREATE USER lucy IDENTIFIED WITH ssl_certificate CN \\'client2\\', \\'client3\\'\n" + + assert instance.query("SELECT name, auth_type, auth_params FROM system.users WHERE name IN ['emma', 'lucy'] ORDER BY name") ==\ + "emma\tssl_certificate\t{\"common_names\":[\"client2\"]}\n"\ + "lucy\tssl_certificate\t{\"common_names\":[\"client2\",\"client3\"]}\n" diff --git a/tests/queries/0_stateless/01316_create_user_syntax_hilite.reference b/tests/queries/0_stateless/01316_create_user_syntax_hilite.reference index ed7daeb3609..d1e2cba5663 100644 --- a/tests/queries/0_stateless/01316_create_user_syntax_hilite.reference +++ b/tests/queries/0_stateless/01316_create_user_syntax_hilite.reference @@ -1 +1 @@ -CREATE USER user IDENTIFIED WITH plaintext_password BY 'hello' +CREATE USER user IDENTIFIED WITH plaintext_password BY 'hello' diff --git a/tests/queries/0_stateless/02117_show_create_table_system.reference b/tests/queries/0_stateless/02117_show_create_table_system.reference index 234804f1078..44dbf86ee9c 100644 --- a/tests/queries/0_stateless/02117_show_create_table_system.reference +++ b/tests/queries/0_stateless/02117_show_create_table_system.reference @@ -60,7 +60,7 @@ CREATE TABLE system.table_functions\n(\n `name` String\n)\nENGINE = SystemTab CREATE TABLE system.tables\n(\n `database` String,\n `name` String,\n `uuid` UUID,\n `engine` String,\n `is_temporary` UInt8,\n `data_paths` Array(String),\n `metadata_path` String,\n `metadata_modification_time` DateTime,\n `dependencies_database` Array(String),\n `dependencies_table` Array(String),\n `create_table_query` String,\n `engine_full` String,\n `as_select` String,\n `partition_key` String,\n `sorting_key` String,\n `primary_key` String,\n `sampling_key` String,\n `storage_policy` String,\n `total_rows` Nullable(UInt64),\n `total_bytes` Nullable(UInt64),\n `lifetime_rows` Nullable(UInt64),\n `lifetime_bytes` Nullable(UInt64),\n `comment` String,\n `has_own_data` UInt8,\n `loading_dependencies_database` Array(String),\n `loading_dependencies_table` Array(String),\n `loading_dependent_database` Array(String),\n `loading_dependent_table` Array(String),\n `table` String\n)\nENGINE = SystemTables()\nCOMMENT \'SYSTEM TABLE is built on the fly.\' CREATE TABLE system.time_zones\n(\n `time_zone` String\n)\nENGINE = SystemTimeZones()\nCOMMENT \'SYSTEM TABLE is built on the fly.\' CREATE TABLE system.user_directories\n(\n `name` String,\n `type` String,\n `params` String,\n `precedence` UInt64\n)\nENGINE = SystemUserDirectories()\nCOMMENT \'SYSTEM TABLE is built on the fly.\' -CREATE TABLE system.users\n(\n `name` String,\n `id` UUID,\n `storage` String,\n `auth_type` Enum8(\'no_password\' = 0, \'plaintext_password\' = 1, \'sha256_password\' = 2, \'double_sha1_password\' = 3, \'ldap\' = 4, \'kerberos\' = 5),\n `auth_params` String,\n `host_ip` Array(String),\n `host_names` Array(String),\n `host_names_regexp` Array(String),\n `host_names_like` Array(String),\n `default_roles_all` UInt8,\n `default_roles_list` Array(String),\n `default_roles_except` Array(String),\n `grantees_any` UInt8,\n `grantees_list` Array(String),\n `grantees_except` Array(String),\n `default_database` String\n)\nENGINE = SystemUsers()\nCOMMENT \'SYSTEM TABLE is built on the fly.\' +CREATE TABLE system.users\n(\n `name` String,\n `id` UUID,\n `storage` String,\n `auth_type` Enum8(\'no_password\' = 0, \'plaintext_password\' = 1, \'sha256_password\' = 2, \'double_sha1_password\' = 3, \'ldap\' = 4, \'kerberos\' = 5, \'ssl_certificate\' = 6),\n `auth_params` String,\n `host_ip` Array(String),\n `host_names` Array(String),\n `host_names_regexp` Array(String),\n `host_names_like` Array(String),\n `default_roles_all` UInt8,\n `default_roles_list` Array(String),\n `default_roles_except` Array(String),\n `grantees_any` UInt8,\n `grantees_list` Array(String),\n `grantees_except` Array(String),\n `default_database` String\n)\nENGINE = SystemUsers()\nCOMMENT \'SYSTEM TABLE is built on the fly.\' CREATE TABLE system.warnings\n(\n `message` String\n)\nENGINE = SystemWarnings()\nCOMMENT \'SYSTEM TABLE is built on the fly.\' CREATE TABLE system.zeros\n(\n `zero` UInt8\n)\nENGINE = SystemZeros()\nCOMMENT \'SYSTEM TABLE is built on the fly.\' CREATE TABLE system.zeros_mt\n(\n `zero` UInt8\n)\nENGINE = SystemZeros()\nCOMMENT \'SYSTEM TABLE is built on the fly.\'