From c0e1095e66670289cc83d8d104bd508412029a85 Mon Sep 17 00:00:00 2001 From: Arthur Passos Date: Wed, 26 Jun 2024 17:06:38 -0300 Subject: [PATCH] initial fix for session_log, shall be refactored --- src/Access/Authentication.cpp | 23 +---------------------- src/Access/IAccessStorage.cpp | 27 +++++++++++++++------------ src/Access/IAccessStorage.h | 6 +++++- src/Access/LDAPAccessStorage.cpp | 2 +- src/Interpreters/Session.cpp | 4 +++- src/Interpreters/Session.h | 1 + src/Interpreters/SessionLog.cpp | 11 ++++++----- src/Interpreters/SessionLog.h | 5 +++-- 8 files changed, 35 insertions(+), 44 deletions(-) diff --git a/src/Access/Authentication.cpp b/src/Access/Authentication.cpp index 827f2e17e94..7174c66d4b1 100644 --- a/src/Access/Authentication.cpp +++ b/src/Access/Authentication.cpp @@ -249,27 +249,6 @@ namespace return false; } #endif - - [[noreturn]] void throwInvalidCredentialsException(const std::vector & authentication_methods) - { - std::string possible_authentication_types; - bool first = true; - - for (const auto & authentication_method : authentication_methods) - { - if (!first) - { - possible_authentication_types += ", "; - } - possible_authentication_types += toString(authentication_method.getType()); - first = false; - } - - throw Exception( - ErrorCodes::NOT_IMPLEMENTED, - "areCredentialsValid(): Invalid credentials provided, available authentication methods are {}", - possible_authentication_types); - } } bool Authentication::areCredentialsValid( @@ -311,7 +290,7 @@ bool Authentication::areCredentialsValid( if ([[maybe_unused]] const auto * always_allow_credentials = typeid_cast(&credentials)) return true; - throwInvalidCredentialsException(authentication_methods); + return false; } } diff --git a/src/Access/IAccessStorage.cpp b/src/Access/IAccessStorage.cpp index f1725dafdf4..5b29c07e046 100644 --- a/src/Access/IAccessStorage.cpp +++ b/src/Access/IAccessStorage.cpp @@ -521,7 +521,7 @@ std::optional IAccessStorage::authenticateImpl( { if (auto user = tryRead(*id)) { - AuthResult auth_result { .user_id = *id }; + AuthResult auth_result { .user_id = *id, .authentication_data = std::nullopt }; if (!isAddressAllowed(*user, address)) throwAddressNotAllowed(address); @@ -534,12 +534,15 @@ std::optional IAccessStorage::authenticateImpl( if (((auth_type == AuthenticationType::NO_PASSWORD) && !allow_no_password) || ((auth_type == AuthenticationType::PLAINTEXT_PASSWORD) && !allow_plaintext_password)) throwAuthenticationTypeNotAllowed(auth_type); + + if (areCredentialsValid(user->getName(), user->valid_until, auth_method, credentials, external_authenticators, auth_result.settings)) + { + auth_result.authentication_data = auth_method; + return auth_result; + } } - if (!areCredentialsValid(*user, credentials, external_authenticators, auth_result.settings)) - throwInvalidCredentials(); - - return auth_result; + throwInvalidCredentials(); } } @@ -549,9 +552,10 @@ std::optional IAccessStorage::authenticateImpl( return std::nullopt; } - bool IAccessStorage::areCredentialsValid( - const User & user, + const std::string user_name, + time_t valid_until, + const AuthenticationData & authentication_method, const Credentials & credentials, const ExternalAuthenticators & external_authenticators, SettingsChanges & settings) const @@ -559,21 +563,20 @@ bool IAccessStorage::areCredentialsValid( if (!credentials.isReady()) return false; - if (credentials.getUserName() != user.getName()) + if (credentials.getUserName() != user_name) return false; - if (user.valid_until) + if (valid_until) { const time_t now = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); - if (now > user.valid_until) + if (now > valid_until) return false; } - return Authentication::areCredentialsValid(credentials, user.authentication_methods, external_authenticators, settings); + return Authentication::areCredentialsValid(credentials, {authentication_method}, external_authenticators, settings); } - bool IAccessStorage::isAddressAllowed(const User & user, const Poco::Net::IPAddress & address) const { return user.allowed_client_hosts.contains(address); diff --git a/src/Access/IAccessStorage.h b/src/Access/IAccessStorage.h index e88b1601f32..df2e8739f7a 100644 --- a/src/Access/IAccessStorage.h +++ b/src/Access/IAccessStorage.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -34,6 +35,7 @@ struct AuthResult UUID user_id; /// Session settings received from authentication server (if any) SettingsChanges settings{}; + std::optional authentication_data; }; /// Contains entities, i.e. instances of classes derived from IAccessEntity. @@ -227,7 +229,9 @@ protected: bool allow_no_password, bool allow_plaintext_password) const; virtual bool areCredentialsValid( - const User & user, + const std::string user_name, + time_t valid_until, + const AuthenticationData & authentication_method, const Credentials & credentials, const ExternalAuthenticators & external_authenticators, SettingsChanges & settings) const; diff --git a/src/Access/LDAPAccessStorage.cpp b/src/Access/LDAPAccessStorage.cpp index 89920878fef..f6435b0d899 100644 --- a/src/Access/LDAPAccessStorage.cpp +++ b/src/Access/LDAPAccessStorage.cpp @@ -504,7 +504,7 @@ std::optional LDAPAccessStorage::authenticateImpl( } if (id) - return AuthResult{ .user_id = *id }; + return AuthResult{ .user_id = *id, .authentication_data = AuthenticationData(AuthenticationType::LDAP) }; return std::nullopt; } diff --git a/src/Interpreters/Session.cpp b/src/Interpreters/Session.cpp index 8bacc3ac8fc..f56fcb2172f 100644 --- a/src/Interpreters/Session.cpp +++ b/src/Interpreters/Session.cpp @@ -361,6 +361,7 @@ void Session::authenticate(const Credentials & credentials_, const Poco::Net::So { auto auth_result = global_context->getAccessControl().authenticate(credentials_, address.host(), getClientInfo().getLastForwardedFor()); user_id = auth_result.user_id; + user_authenticated_with = auth_result.authentication_data; settings_from_auth_server = auth_result.settings; LOG_DEBUG(log, "{} Authenticated with global context as user {}", toString(auth_id), toString(*user_id)); @@ -705,7 +706,8 @@ void Session::recordLoginSuccess(ContextPtr login_context) const settings, access->getAccess(), getClientInfo(), - user); + user, + *user_authenticated_with); } notified_session_log_about_login = true; diff --git a/src/Interpreters/Session.h b/src/Interpreters/Session.h index f671969b1d4..6281df25de3 100644 --- a/src/Interpreters/Session.h +++ b/src/Interpreters/Session.h @@ -113,6 +113,7 @@ private: mutable UserPtr user; std::optional user_id; + std::optional user_authenticated_with; ContextMutablePtr session_context; mutable bool query_context_created = false; diff --git a/src/Interpreters/SessionLog.cpp b/src/Interpreters/SessionLog.cpp index 4832450af48..30c3bd570f3 100644 --- a/src/Interpreters/SessionLog.cpp +++ b/src/Interpreters/SessionLog.cpp @@ -208,13 +208,13 @@ void SessionLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insertData(auth_failure_reason.data(), auth_failure_reason.length()); } -// todo arthur fix this method void SessionLog::addLoginSuccess(const UUID & auth_id, const String & session_id, const Settings & settings, const ContextAccessPtr & access, const ClientInfo & client_info, - const UserPtr & login_user) + const UserPtr & login_user, + const AuthenticationData & user_authenticated_with) { SessionLogElement log_entry(auth_id, SESSION_LOGIN_SUCCESS); log_entry.client_info = client_info; @@ -222,10 +222,11 @@ void SessionLog::addLoginSuccess(const UUID & auth_id, if (login_user) { log_entry.user = login_user->getName(); - log_entry.user_identified_with = login_user->authentication_methods.back().getType(); -// log_entry.user_identified_with = login_user->auth_data.getType(); + log_entry.user_identified_with = user_authenticated_with.getType(); } - log_entry.external_auth_server = login_user ? login_user->authentication_methods.back().getLDAPServerName() : ""; + + log_entry.external_auth_server = user_authenticated_with.getLDAPServerName(); + log_entry.session_id = session_id; diff --git a/src/Interpreters/SessionLog.h b/src/Interpreters/SessionLog.h index 5bacb9677c0..2037d6c6861 100644 --- a/src/Interpreters/SessionLog.h +++ b/src/Interpreters/SessionLog.h @@ -22,6 +22,7 @@ class ContextAccess; struct User; using UserPtr = std::shared_ptr; using ContextAccessPtr = std::shared_ptr; +class AuthenticationData; /** A struct which will be inserted as row into session_log table. * @@ -71,14 +72,14 @@ struct SessionLogElement class SessionLog : public SystemLog { using SystemLog::SystemLog; - public: void addLoginSuccess(const UUID & auth_id, const String & session_id, const Settings & settings, const ContextAccessPtr & access, const ClientInfo & client_info, - const UserPtr & login_user); + const UserPtr & login_user, + const AuthenticationData & user_authenticated_with); void addLoginFailure(const UUID & auth_id, const ClientInfo & info, const std::optional & user, const Exception & reason); void addLogOut(const UUID & auth_id, const UserPtr & login_user, const ClientInfo & client_info);