From 39d9c315d4fb12b2606e504068a0777f76387fbf Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Fri, 28 Feb 2020 22:53:18 +0300 Subject: [PATCH 1/2] Simplify the message about authentication: now the same message is used in all cases of failed authentication. --- dbms/src/Access/AccessRightsContext.cpp | 12 ++++++------ dbms/src/Access/AccessRightsContext.h | 4 ++-- dbms/src/Access/AllowedClientHosts.cpp | 13 ------------- dbms/src/Access/AllowedClientHosts.h | 4 ---- dbms/src/Access/Authentication.cpp | 13 ------------- dbms/src/Access/Authentication.h | 4 ---- dbms/src/Common/ErrorCodes.cpp | 1 + dbms/src/Interpreters/Context.cpp | 19 +++++++++++++++---- 8 files changed, 24 insertions(+), 46 deletions(-) diff --git a/dbms/src/Access/AccessRightsContext.cpp b/dbms/src/Access/AccessRightsContext.cpp index 9a32a1234f0..471db279751 100644 --- a/dbms/src/Access/AccessRightsContext.cpp +++ b/dbms/src/Access/AccessRightsContext.cpp @@ -186,20 +186,20 @@ void AccessRightsContext::setRolesInfo(const CurrentRolesInfoPtr & roles_info_) } -void AccessRightsContext::checkPassword(const String & password) const +bool AccessRightsContext::isCorrectPassword(const String & password) const { std::lock_guard lock{mutex}; if (!user) - throw Exception(user_name + ": User has been dropped", ErrorCodes::UNKNOWN_USER); - user->authentication.checkPassword(password, user_name); + return false; + return user->authentication.isCorrectPassword(password); } -void AccessRightsContext::checkHostIsAllowed() const +bool AccessRightsContext::isClientHostAllowed() const { std::lock_guard lock{mutex}; if (!user) - throw Exception(user_name + ": User has been dropped", ErrorCodes::UNKNOWN_USER); - user->allowed_client_hosts.checkContains(params.address, user_name); + return false; + return user->allowed_client_hosts.contains(params.address); } diff --git a/dbms/src/Access/AccessRightsContext.h b/dbms/src/Access/AccessRightsContext.h index f129d70162d..be6101e5e9b 100644 --- a/dbms/src/Access/AccessRightsContext.h +++ b/dbms/src/Access/AccessRightsContext.h @@ -60,8 +60,8 @@ public: UserPtr getUser() const; String getUserName() const; - void checkPassword(const String & password) const; - void checkHostIsAllowed() const; + bool isCorrectPassword(const String & password) const; + bool isClientHostAllowed() const; CurrentRolesInfoPtr getRolesInfo() const; std::vector getCurrentRoles() const; diff --git a/dbms/src/Access/AllowedClientHosts.cpp b/dbms/src/Access/AllowedClientHosts.cpp index d45fdce2354..8ef2a868acc 100644 --- a/dbms/src/Access/AllowedClientHosts.cpp +++ b/dbms/src/Access/AllowedClientHosts.cpp @@ -15,7 +15,6 @@ namespace DB namespace ErrorCodes { extern const int DNS_ERROR; - extern const int IP_ADDRESS_NOT_ALLOWED; } namespace @@ -367,16 +366,4 @@ bool AllowedClientHosts::contains(const IPAddress & client_address) const return false; } - -void AllowedClientHosts::checkContains(const IPAddress & address, const String & user_name) const -{ - if (!contains(address)) - { - if (user_name.empty()) - throw Exception("It's not allowed to connect from address " + address.toString(), ErrorCodes::IP_ADDRESS_NOT_ALLOWED); - else - throw Exception("User " + user_name + " is not allowed to connect from address " + address.toString(), ErrorCodes::IP_ADDRESS_NOT_ALLOWED); - } -} - } diff --git a/dbms/src/Access/AllowedClientHosts.h b/dbms/src/Access/AllowedClientHosts.h index 7eb65a7023b..9e89c2b92a1 100644 --- a/dbms/src/Access/AllowedClientHosts.h +++ b/dbms/src/Access/AllowedClientHosts.h @@ -111,10 +111,6 @@ public: /// Checks if the provided address is in the list. Returns false if not. bool contains(const IPAddress & address) const; - /// Checks if the provided address is in the list. Throws an exception if not. - /// `username` is only used for generating an error message if the address isn't in the list. - void checkContains(const IPAddress & address, const String & user_name = String()) const; - friend bool operator ==(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs); friend bool operator !=(const AllowedClientHosts & lhs, const AllowedClientHosts & rhs) { return !(lhs == rhs); } diff --git a/dbms/src/Access/Authentication.cpp b/dbms/src/Access/Authentication.cpp index 5d2bf8614ba..f435d6e6336 100644 --- a/dbms/src/Access/Authentication.cpp +++ b/dbms/src/Access/Authentication.cpp @@ -9,8 +9,6 @@ namespace ErrorCodes { extern const int LOGICAL_ERROR; extern const int BAD_ARGUMENTS; - extern const int REQUIRED_PASSWORD; - extern const int WRONG_PASSWORD; } @@ -77,15 +75,4 @@ bool Authentication::isCorrectPassword(const String & password_) const throw Exception("Unknown authentication type: " + std::to_string(static_cast(type)), ErrorCodes::LOGICAL_ERROR); } - -void Authentication::checkPassword(const String & password_, const String & user_name) const -{ - if (isCorrectPassword(password_)) - return; - auto info_about_user_name = [&user_name]() { return user_name.empty() ? String() : " for user " + user_name; }; - if (password_.empty() && (type != NO_PASSWORD)) - throw Exception("Password required" + info_about_user_name(), ErrorCodes::REQUIRED_PASSWORD); - throw Exception("Wrong password" + info_about_user_name(), ErrorCodes::WRONG_PASSWORD); -} - } diff --git a/dbms/src/Access/Authentication.h b/dbms/src/Access/Authentication.h index e81ab9fffdf..3f16dc56de3 100644 --- a/dbms/src/Access/Authentication.h +++ b/dbms/src/Access/Authentication.h @@ -70,10 +70,6 @@ public: /// Checks if the provided password is correct. Returns false if not. bool isCorrectPassword(const String & password) const; - /// Checks if the provided password is correct. Throws an exception if not. - /// `user_name` is only used for generating an error message if the password is incorrect. - void checkPassword(const String & password, const String & user_name = String()) const; - friend bool operator ==(const Authentication & lhs, const Authentication & rhs) { return (lhs.type == rhs.type) && (lhs.password_hash == rhs.password_hash); } friend bool operator !=(const Authentication & lhs, const Authentication & rhs) { return !(lhs == rhs); } diff --git a/dbms/src/Common/ErrorCodes.cpp b/dbms/src/Common/ErrorCodes.cpp index 934c7036666..61792b2bb80 100644 --- a/dbms/src/Common/ErrorCodes.cpp +++ b/dbms/src/Common/ErrorCodes.cpp @@ -487,6 +487,7 @@ namespace ErrorCodes extern const int UNKNOWN_PART_TYPE = 513; extern const int ACCESS_STORAGE_FOR_INSERTION_NOT_FOUND = 514; extern const int INCORRECT_ACCESS_ENTITY_DEFINITION = 515; + extern const int AUTHENTICATION_FAILED = 516; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index 572a6b97897..6e1978f03d8 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -92,6 +92,7 @@ namespace ErrorCodes extern const int SESSION_IS_LOCKED; extern const int LOGICAL_ERROR; extern const int UNKNOWN_SCALAR; + extern const int AUTHENTICATION_FAILED; } @@ -646,10 +647,20 @@ void Context::setUser(const String & name, const String & password, const Poco:: if (!quota_key.empty()) client_info.quota_key = quota_key; - auto new_user_id = getAccessControlManager().getID(name); - auto new_access_rights = getAccessControlManager().getAccessRightsContext(new_user_id, {}, true, settings, current_database, client_info); - new_access_rights->checkHostIsAllowed(); - new_access_rights->checkPassword(password); + auto new_user_id = getAccessControlManager().find(name); + AccessRightsContextPtr new_access_rights; + if (new_user_id) + { + new_access_rights = getAccessControlManager().getAccessRightsContext(*new_user_id, {}, true, settings, current_database, client_info); + if (!new_access_rights->isClientHostAllowed() || !new_access_rights->isCorrectPassword(password)) + { + new_user_id = {}; + new_access_rights = nullptr; + } + } + + if (!new_user_id || !new_access_rights) + throw Exception(name + ": Authentication failed: password is incorrect or there is no user with such name", ErrorCodes::AUTHENTICATION_FAILED); user_id = new_user_id; access_rights = std::move(new_access_rights); From 038e52096011b42f1ea2e6004b4449187e95a1ee Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Sat, 29 Feb 2020 15:57:52 +0300 Subject: [PATCH 2/2] Add tests. --- dbms/tests/integration/helpers/client.py | 17 ++++++---- dbms/tests/integration/helpers/cluster.py | 22 +++++++------ .../test_allowed_client_hosts/test.py | 2 +- .../test_authentication/__init__.py | 0 .../integration/test_authentication/test.py | 32 +++++++++++++++++++ .../integration/test_mysql_protocol/test.py | 4 +-- 6 files changed, 57 insertions(+), 20 deletions(-) create mode 100644 dbms/tests/integration/test_authentication/__init__.py create mode 100644 dbms/tests/integration/test_authentication/test.py diff --git a/dbms/tests/integration/helpers/client.py b/dbms/tests/integration/helpers/client.py index e986a9ef7c8..10962cfb724 100644 --- a/dbms/tests/integration/helpers/client.py +++ b/dbms/tests/integration/helpers/client.py @@ -17,11 +17,11 @@ class Client: self.command += ['--host', self.host, '--port', str(self.port), '--stacktrace'] - def query(self, sql, stdin=None, timeout=None, settings=None, user=None, ignore_error=False): - return self.get_query_request(sql, stdin=stdin, timeout=timeout, settings=settings, user=user, ignore_error=ignore_error).get_answer() + def query(self, sql, stdin=None, timeout=None, settings=None, user=None, password=None, ignore_error=False): + return self.get_query_request(sql, stdin=stdin, timeout=timeout, settings=settings, user=user, password=password, ignore_error=ignore_error).get_answer() - def get_query_request(self, sql, stdin=None, timeout=None, settings=None, user=None, ignore_error=False): + def get_query_request(self, sql, stdin=None, timeout=None, settings=None, user=None, password=None, ignore_error=False): command = self.command[:] if stdin is None: @@ -37,15 +37,18 @@ class Client: if user is not None: command += ['--user', user] + if password is not None: + command += ['--password', password] + return CommandRequest(command, stdin, timeout, ignore_error) - def query_and_get_error(self, sql, stdin=None, timeout=None, settings=None, user=None): - return self.get_query_request(sql, stdin=stdin, timeout=timeout, settings=settings, user=user).get_error() + def query_and_get_error(self, sql, stdin=None, timeout=None, settings=None, user=None, password=None): + return self.get_query_request(sql, stdin=stdin, timeout=timeout, settings=settings, user=user, password=password).get_error() - def query_and_get_answer_with_error(self, sql, stdin=None, timeout=None, settings=None, user=None): - return self.get_query_request(sql, stdin=stdin, timeout=timeout, settings=settings, user=user).get_answer_and_error() + def query_and_get_answer_with_error(self, sql, stdin=None, timeout=None, settings=None, user=None, password=None): + return self.get_query_request(sql, stdin=stdin, timeout=timeout, settings=settings, user=user, password=password).get_answer_and_error() class QueryTimeoutExceedException(Exception): pass diff --git a/dbms/tests/integration/helpers/cluster.py b/dbms/tests/integration/helpers/cluster.py index 991a359967b..bc736ee9990 100644 --- a/dbms/tests/integration/helpers/cluster.py +++ b/dbms/tests/integration/helpers/cluster.py @@ -619,15 +619,15 @@ class ClickHouseInstance: self.with_installed_binary = with_installed_binary # Connects to the instance via clickhouse-client, sends a query (1st argument) and returns the answer - def query(self, sql, stdin=None, timeout=None, settings=None, user=None, ignore_error=False): - return self.client.query(sql, stdin, timeout, settings, user, ignore_error) + def query(self, sql, stdin=None, timeout=None, settings=None, user=None, password=None, ignore_error=False): + return self.client.query(sql, stdin, timeout, settings, user, password, ignore_error) - def query_with_retry(self, sql, stdin=None, timeout=None, settings=None, user=None, ignore_error=False, + def query_with_retry(self, sql, stdin=None, timeout=None, settings=None, user=None, password=None, ignore_error=False, retry_count=20, sleep_time=0.5, check_callback=lambda x: True): result = None for i in range(retry_count): try: - result = self.query(sql, stdin, timeout, settings, user, ignore_error) + result = self.query(sql, stdin, timeout, settings, user, password, ignore_error) if check_callback(result): return result time.sleep(sleep_time) @@ -644,15 +644,15 @@ class ClickHouseInstance: return self.client.get_query_request(*args, **kwargs) # Connects to the instance via clickhouse-client, sends a query (1st argument), expects an error and return its code - def query_and_get_error(self, sql, stdin=None, timeout=None, settings=None, user=None): - return self.client.query_and_get_error(sql, stdin, timeout, settings, user) + def query_and_get_error(self, sql, stdin=None, timeout=None, settings=None, user=None, password=None): + return self.client.query_and_get_error(sql, stdin, timeout, settings, user, password) # The same as query_and_get_error but ignores successful query. - def query_and_get_answer_with_error(self, sql, stdin=None, timeout=None, settings=None, user=None): - return self.client.query_and_get_answer_with_error(sql, stdin, timeout, settings, user) + def query_and_get_answer_with_error(self, sql, stdin=None, timeout=None, settings=None, user=None, password=None): + return self.client.query_and_get_answer_with_error(sql, stdin, timeout, settings, user, password) # Connects to the instance via HTTP interface, sends a query and returns the answer - def http_query(self, sql, data=None, params=None, user=None): + def http_query(self, sql, data=None, params=None, user=None, password=None): if params is None: params = {} else: @@ -661,7 +661,9 @@ class ClickHouseInstance: params["query"] = sql auth = "" - if user: + if user and password: + auth = "{}:{}@".format(user, password) + elif user: auth = "{}@".format(user) url = "http://" + auth + self.ip_address + ":8123/?" + urllib.urlencode(params) diff --git a/dbms/tests/integration/test_allowed_client_hosts/test.py b/dbms/tests/integration/test_allowed_client_hosts/test.py index fcdf408c88a..23f7f0a4abd 100644 --- a/dbms/tests/integration/test_allowed_client_hosts/test.py +++ b/dbms/tests/integration/test_allowed_client_hosts/test.py @@ -57,4 +57,4 @@ def test_allowed_host(): for client_node in expected_to_fail: with pytest.raises(Exception) as e: query_from_one_node_to_another(client_node, server, "SELECT * FROM test_table") - assert "User default is not allowed to connect from address" in str(e) + assert "default: Authentication failed" in str(e) diff --git a/dbms/tests/integration/test_authentication/__init__.py b/dbms/tests/integration/test_authentication/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dbms/tests/integration/test_authentication/test.py b/dbms/tests/integration/test_authentication/test.py new file mode 100644 index 00000000000..11ca967fbee --- /dev/null +++ b/dbms/tests/integration/test_authentication/test.py @@ -0,0 +1,32 @@ +import pytest +from helpers.cluster import ClickHouseCluster + +cluster = ClickHouseCluster(__file__) +instance = cluster.add_instance('instance') + + +@pytest.fixture(scope="module", autouse=True) +def setup_nodes(): + try: + cluster.start() + + instance.query("CREATE USER sasha PROFILE 'default'") + instance.query("CREATE USER masha IDENTIFIED BY 'qwerty' PROFILE 'default'") + + yield cluster + + finally: + cluster.shutdown() + + +def test_authentication_pass(): + assert instance.query("SELECT currentUser()", user='sasha') == 'sasha\n' + assert instance.query("SELECT currentUser()", user='masha', password='qwerty') == 'masha\n' + + +def test_authentication_fail(): + # User doesn't exist. + assert "vasya: Authentication failed" in instance.query_and_get_error("SELECT currentUser()", user = 'vasya') + + # Wrong password. + assert "masha: Authentication failed" in instance.query_and_get_error("SELECT currentUser()", user = 'masha', password = '123') diff --git a/dbms/tests/integration/test_mysql_protocol/test.py b/dbms/tests/integration/test_mysql_protocol/test.py index 3f4f4e2a2f8..7987076c29a 100644 --- a/dbms/tests/integration/test_mysql_protocol/test.py +++ b/dbms/tests/integration/test_mysql_protocol/test.py @@ -101,7 +101,7 @@ def test_mysql_client(mysql_client, server_address): '''.format(host=server_address, port=server_port), demux=True) assert stderr == 'mysql: [Warning] Using a password on the command line interface can be insecure.\n' \ - 'ERROR 193 (00000): Wrong password for user default\n' + 'ERROR 516 (00000): default: Authentication failed: password is incorrect or there is no user with such name\n' code, (stdout, stderr) = mysql_client.exec_run(''' mysql --protocol tcp -h {host} -P {port} default -u default --password=123 @@ -179,7 +179,7 @@ def test_python_client(server_address): with pytest.raises(pymysql.InternalError) as exc_info: pymysql.connections.Connection(host=server_address, user='default', password='abacab', database='default', port=server_port) - assert exc_info.value.args == (193, 'Wrong password for user default') + assert exc_info.value.args == (516, 'default: Authentication failed: password is incorrect or there is no user with such name') client = pymysql.connections.Connection(host=server_address, user='default', password='123', database='default', port=server_port)