append default constructed auth method upon alter without auth data

This commit is contained in:
Arthur Passos 2024-06-14 10:58:57 -03:00
parent 98e5ea5206
commit c1250ccb35
2 changed files with 98 additions and 59 deletions

View File

@ -83,38 +83,30 @@ namespace
return false; return false;
} }
#endif #endif
}
static std::vector<AuthenticationData> getAuthenticationMethodsOfType( std::vector<AuthenticationData> getAuthenticationMethodsOfType(
const std::vector<AuthenticationData> & authentication_methods, const std::vector<AuthenticationData> & authentication_methods,
const std::unordered_set<AuthenticationType> & types) const std::unordered_set<AuthenticationType> & types)
{
std::vector<AuthenticationData> authentication_methods_of_type;
for (const auto & authentication_method : authentication_methods)
{ {
if (types.contains(authentication_method.getType())) std::vector<AuthenticationData> authentication_methods_of_type;
for (const auto & authentication_method : authentication_methods)
{ {
authentication_methods_of_type.push_back(authentication_method); if (types.contains(authentication_method.getType()))
{
authentication_methods_of_type.push_back(authentication_method);
}
} }
return authentication_methods_of_type;
} }
return authentication_methods_of_type; bool checkKerberosAuthentication(
} const GSSAcceptorContext * gss_acceptor_context,
const std::vector<AuthenticationData> & authentication_methods,
bool Authentication::areCredentialsValid( const ExternalAuthenticators & external_authenticators)
const Credentials & credentials,
const std::vector<AuthenticationData> & authentication_methods,
const ExternalAuthenticators & external_authenticators,
SettingsChanges & settings)
{
if (!credentials.isReady())
return false;
if (const auto * gss_acceptor_context = typeid_cast<const GSSAcceptorContext *>(&credentials))
{ {
auto kerberos_authentication_methods = getAuthenticationMethodsOfType(authentication_methods, {AuthenticationType::KERBEROS}); auto kerberos_authentication_methods = getAuthenticationMethodsOfType(authentication_methods, {AuthenticationType::KERBEROS});
for (const auto & kerberos_authentication : kerberos_authentication_methods) for (const auto & kerberos_authentication : kerberos_authentication_methods)
{ {
if (external_authenticators.checkKerberosCredentials(kerberos_authentication.getKerberosRealm(), *gss_acceptor_context)) if (external_authenticators.checkKerberosCredentials(kerberos_authentication.getKerberosRealm(), *gss_acceptor_context))
@ -122,11 +114,12 @@ bool Authentication::areCredentialsValid(
return true; return true;
} }
} }
return false; return false;
} }
if (const auto * mysql_credentials = typeid_cast<const MySQLNative41Credentials *>(&credentials)) bool checkMySQLAuthentication(
const MySQLNative41Credentials * mysql_credentials,
const std::vector<AuthenticationData> & authentication_methods)
{ {
auto mysql_authentication_methods = getAuthenticationMethodsOfType( auto mysql_authentication_methods = getAuthenticationMethodsOfType(
authentication_methods, authentication_methods,
@ -138,7 +131,7 @@ bool Authentication::areCredentialsValid(
{ {
case AuthenticationType::PLAINTEXT_PASSWORD: case AuthenticationType::PLAINTEXT_PASSWORD:
if (checkPasswordPlainTextMySQL( if (checkPasswordPlainTextMySQL(
mysql_credentials->getScramble(), mysql_credentials->getScrambledPassword(), mysql_authentication_method.getPasswordHashBinary())) mysql_credentials->getScramble(), mysql_credentials->getScrambledPassword(), mysql_authentication_method.getPasswordHashBinary()))
{ {
return true; return true;
} }
@ -154,14 +147,17 @@ bool Authentication::areCredentialsValid(
} }
break; break;
default: default:
throw Exception(ErrorCodes::LOGICAL_ERROR, "something bad happened"); throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid MySQL authentication type");
} }
} }
return false; return false;
} }
if (const auto * basic_credentials = typeid_cast<const BasicCredentials *>(&credentials)) bool checkBasicAuthentication(
const BasicCredentials * basic_credentials,
const std::vector<AuthenticationData> & authentication_methods,
const ExternalAuthenticators & external_authenticators,
SettingsChanges & settings)
{ {
auto basic_credentials_authentication_methods = getAuthenticationMethodsOfType( auto basic_credentials_authentication_methods = getAuthenticationMethodsOfType(
authentication_methods, authentication_methods,
@ -212,17 +208,17 @@ bool Authentication::areCredentialsValid(
} }
break; break;
default: default:
throw Exception(ErrorCodes::LOGICAL_ERROR, "something bad happened"); throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid basic authentication type");
} }
} }
return false; return false;
} }
if (const auto * ssl_certificate_credentials = typeid_cast<const SSLCertificateCredentials *>(&credentials)) bool checkSSLCertificateAuthentication(
const SSLCertificateCredentials * ssl_certificate_credentials,
const std::vector<AuthenticationData> & authentication_methods)
{ {
const auto ssl_certificate_authentication_methods = getAuthenticationMethodsOfType(authentication_methods, {AuthenticationType::SSL_CERTIFICATE}); const auto ssl_certificate_authentication_methods = getAuthenticationMethodsOfType(authentication_methods, {AuthenticationType::SSL_CERTIFICATE});
for (const auto & auth_method : ssl_certificate_authentication_methods) for (const auto & auth_method : ssl_certificate_authentication_methods)
{ {
if (auth_method.getSSLCertificateCommonNames().contains(ssl_certificate_credentials->getCommonName())) if (auth_method.getSSLCertificateCommonNames().contains(ssl_certificate_credentials->getCommonName()))
@ -230,15 +226,15 @@ bool Authentication::areCredentialsValid(
return true; return true;
} }
} }
return false; return false;
} }
#if USE_SSH #if USE_SSH
if (const auto * ssh_credentials = typeid_cast<const SshCredentials *>(&credentials)) bool checkSshAuthentication(
const SshCredentials * ssh_credentials,
const std::vector<AuthenticationData> & authentication_methods)
{ {
const auto ssh_authentication_methods = getAuthenticationMethodsOfType(authentication_methods, {AuthenticationType::SSL_CERTIFICATE}); const auto ssh_authentication_methods = getAuthenticationMethodsOfType(authentication_methods, {AuthenticationType::SSH_KEY});
for (const auto & auth_method : ssh_authentication_methods) for (const auto & auth_method : ssh_authentication_methods)
{ {
if (checkSshSignature(auth_method.getSSHKeys(), ssh_credentials->getSignature(), ssh_credentials->getOriginal())) if (checkSshSignature(auth_method.getSSHKeys(), ssh_credentials->getSignature(), ssh_credentials->getOriginal()))
@ -246,35 +242,73 @@ bool Authentication::areCredentialsValid(
return true; return true;
} }
} }
return false; return false;
} }
#endif #endif
[[noreturn]] void throwInvalidCredentialsException(const std::vector<AuthenticationData> & authentication_methods)
{
std::string possible_authentication_types;
bool first = true;
for (const auto & authentication_method : authentication_methods)
{
if (!first)
{
possible_authentication_types += ", ";
}
possible_authentication_types += toString(authentication_method.getType());
first = false;
}
throw Exception(
ErrorCodes::NOT_IMPLEMENTED,
"areCredentialsValid(): Invalid credentials provided, available authentication methods are {}",
possible_authentication_types);
}
}
bool Authentication::areCredentialsValid(
const Credentials & credentials,
const std::vector<AuthenticationData> & authentication_methods,
const ExternalAuthenticators & external_authenticators,
SettingsChanges & settings)
{
if (!credentials.isReady())
return false;
if (const auto * gss_acceptor_context = typeid_cast<const GSSAcceptorContext *>(&credentials))
{
return checkKerberosAuthentication(gss_acceptor_context, authentication_methods, external_authenticators);
}
if (const auto * mysql_credentials = typeid_cast<const MySQLNative41Credentials *>(&credentials))
{
return checkMySQLAuthentication(mysql_credentials, authentication_methods);
}
if (const auto * basic_credentials = typeid_cast<const BasicCredentials *>(&credentials))
{
return checkBasicAuthentication(basic_credentials, authentication_methods, external_authenticators, settings);
}
if (const auto * ssl_certificate_credentials = typeid_cast<const SSLCertificateCredentials *>(&credentials))
{
return checkSSLCertificateAuthentication(ssl_certificate_credentials, authentication_methods);
}
#if USE_SSH
if (const auto * ssh_credentials = typeid_cast<const SshCredentials *>(&credentials))
{
return checkSshAuthentication(ssh_credentials, authentication_methods);
}
#endif
if ([[maybe_unused]] const auto * always_allow_credentials = typeid_cast<const AlwaysAllowCredentials *>(&credentials)) if ([[maybe_unused]] const auto * always_allow_credentials = typeid_cast<const AlwaysAllowCredentials *>(&credentials))
return true; return true;
// below code sucks, but works for now I guess. throwInvalidCredentialsException(authentication_methods);
// might be a problem if no auth method has been registered
std::string possible_authentication_types;
bool first = true;
for (const auto & authentication_method : authentication_methods)
{
if (first)
{
possible_authentication_types += ", ";
first = false;
}
possible_authentication_types += toString(authentication_method.getType());
}
throw Exception(
ErrorCodes::NOT_IMPLEMENTED,
"areCredentialsValid(): Invalid credentials provided, available authentication methods are {}",
possible_authentication_types);
} }
} }

View File

@ -61,6 +61,11 @@ namespace
{ {
user.authentication_methods.push_back(*auth_data); user.authentication_methods.push_back(*auth_data);
} }
else if (user.authentication_methods.empty())
{
// previously, a user always had a default constructed auth method.. maybe I should put this somewhere else
user.authentication_methods.emplace_back();
}
if (reset_authentication_methods) if (reset_authentication_methods)
{ {