small refactor

This commit is contained in:
Arthur Passos 2024-06-27 11:50:16 -03:00
parent 21de3e2961
commit b219b99380
3 changed files with 67 additions and 137 deletions

View File

@ -83,176 +83,107 @@ namespace
}
#endif
std::vector<AuthenticationData> getAuthenticationMethodsOfType(
const std::vector<AuthenticationData> & authentication_methods,
const std::unordered_set<AuthenticationType> & types)
{
std::vector<AuthenticationData> authentication_methods_of_type;
for (const auto & authentication_method : authentication_methods)
{
if (types.contains(authentication_method.getType()))
{
authentication_methods_of_type.push_back(authentication_method);
}
}
return authentication_methods_of_type;
}
bool checkKerberosAuthentication(
const GSSAcceptorContext * gss_acceptor_context,
const std::vector<AuthenticationData> & authentication_methods,
const AuthenticationData & authentication_method,
const ExternalAuthenticators & external_authenticators)
{
auto kerberos_authentication_methods = getAuthenticationMethodsOfType(authentication_methods, {AuthenticationType::KERBEROS});
for (const auto & kerberos_authentication : kerberos_authentication_methods)
{
if (external_authenticators.checkKerberosCredentials(kerberos_authentication.getKerberosRealm(), *gss_acceptor_context))
{
return true;
}
}
return false;
return authentication_method.getType() == AuthenticationType::KERBEROS
&& external_authenticators.checkKerberosCredentials(authentication_method.getKerberosRealm(), *gss_acceptor_context);
}
bool checkMySQLAuthentication(
const MySQLNative41Credentials * mysql_credentials,
const std::vector<AuthenticationData> & authentication_methods)
const AuthenticationData & authentication_method)
{
auto mysql_authentication_methods = getAuthenticationMethodsOfType(
authentication_methods,
{AuthenticationType::PLAINTEXT_PASSWORD, AuthenticationType::DOUBLE_SHA1_PASSWORD});
for (const auto & mysql_authentication_method : mysql_authentication_methods)
switch (authentication_method.getType())
{
switch (mysql_authentication_method.getType())
{
case AuthenticationType::PLAINTEXT_PASSWORD:
if (checkPasswordPlainTextMySQL(
mysql_credentials->getScramble(), mysql_credentials->getScrambledPassword(), mysql_authentication_method.getPasswordHashBinary()))
{
return true;
}
break;
case AuthenticationType::DOUBLE_SHA1_PASSWORD:
if (checkPasswordDoubleSHA1MySQL(
mysql_credentials->getScramble(),
mysql_credentials->getScrambledPassword(),
mysql_authentication_method.getPasswordHashBinary()))
{
return true;
}
break;
default:
throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid MySQL authentication type");
}
case AuthenticationType::PLAINTEXT_PASSWORD:
return checkPasswordPlainTextMySQL(
mysql_credentials->getScramble(),
mysql_credentials->getScrambledPassword(),
authentication_method.getPasswordHashBinary());
case AuthenticationType::DOUBLE_SHA1_PASSWORD:
return checkPasswordDoubleSHA1MySQL(
mysql_credentials->getScramble(),
mysql_credentials->getScrambledPassword(),
authentication_method.getPasswordHashBinary());
default:
return false;
}
return false;
}
bool checkBasicAuthentication(
const BasicCredentials * basic_credentials,
const std::vector<AuthenticationData> & authentication_methods,
const AuthenticationData & authentication_method,
const ExternalAuthenticators & external_authenticators,
SettingsChanges & settings)
{
auto basic_credentials_authentication_methods = getAuthenticationMethodsOfType(
authentication_methods,
{AuthenticationType::NO_PASSWORD, AuthenticationType::PLAINTEXT_PASSWORD, AuthenticationType::SHA256_PASSWORD,
AuthenticationType::DOUBLE_SHA1_PASSWORD, AuthenticationType::LDAP, AuthenticationType::BCRYPT_PASSWORD,
AuthenticationType::HTTP});
for (const auto & auth_method : basic_credentials_authentication_methods)
switch (authentication_method.getType())
{
switch (auth_method.getType())
case AuthenticationType::NO_PASSWORD:
{
case AuthenticationType::NO_PASSWORD:
{
return true;
}
case AuthenticationType::PLAINTEXT_PASSWORD:
if (checkPasswordPlainText(basic_credentials->getPassword(), auth_method.getPasswordHashBinary()))
{
return true;
}
break;
case AuthenticationType::SHA256_PASSWORD:
if (checkPasswordSHA256(basic_credentials->getPassword(), auth_method.getPasswordHashBinary(), auth_method.getSalt()))
{
return true;
}
break;
case AuthenticationType::DOUBLE_SHA1_PASSWORD:
if (checkPasswordDoubleSHA1(basic_credentials->getPassword(), auth_method.getPasswordHashBinary()))
{
return true;
}
break;
case AuthenticationType::LDAP:
if (external_authenticators.checkLDAPCredentials(auth_method.getLDAPServerName(), *basic_credentials))
{
return true;
}
break;
case AuthenticationType::BCRYPT_PASSWORD:
if (checkPasswordBcrypt(basic_credentials->getPassword(), auth_method.getPasswordHashBinary()))
{
return true;
}
break;
case AuthenticationType::HTTP:
if (auth_method.getHTTPAuthenticationScheme() == HTTPAuthenticationScheme::BASIC)
{
return external_authenticators.checkHTTPBasicCredentials(
auth_method.getHTTPAuthenticationServerName(), *basic_credentials, settings);
}
break;
default:
throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid basic authentication type");
return true;
}
case AuthenticationType::PLAINTEXT_PASSWORD:
{
return checkPasswordPlainText(basic_credentials->getPassword(), authentication_method.getPasswordHashBinary());
}
case AuthenticationType::SHA256_PASSWORD:
{
return checkPasswordSHA256(
basic_credentials->getPassword(), authentication_method.getPasswordHashBinary(), authentication_method.getSalt());
}
case AuthenticationType::DOUBLE_SHA1_PASSWORD:
{
return checkPasswordDoubleSHA1(basic_credentials->getPassword(), authentication_method.getPasswordHashBinary());
}
case AuthenticationType::LDAP:
{
return external_authenticators.checkLDAPCredentials(authentication_method.getLDAPServerName(), *basic_credentials);
}
case AuthenticationType::BCRYPT_PASSWORD:
{
return checkPasswordBcrypt(basic_credentials->getPassword(), authentication_method.getPasswordHashBinary());
}
case AuthenticationType::HTTP:
{
if (authentication_method.getHTTPAuthenticationScheme() == HTTPAuthenticationScheme::BASIC)
{
return external_authenticators.checkHTTPBasicCredentials(
authentication_method.getHTTPAuthenticationServerName(), *basic_credentials, settings);
}
break;
}
default:
break;
}
return false;
}
bool checkSSLCertificateAuthentication(
const SSLCertificateCredentials * ssl_certificate_credentials,
const std::vector<AuthenticationData> & authentication_methods)
const AuthenticationData & authentication_method)
{
const auto ssl_certificate_authentication_methods = getAuthenticationMethodsOfType(authentication_methods, {AuthenticationType::SSL_CERTIFICATE});
for (const auto & auth_method : ssl_certificate_authentication_methods)
{
if (auth_method.getSSLCertificateCommonNames().contains(ssl_certificate_credentials->getCommonName()))
{
return true;
}
}
return false;
return AuthenticationType::SSL_CERTIFICATE == authentication_method.getType()
&& authentication_method.getSSLCertificateCommonNames().contains(ssl_certificate_credentials->getCommonName());
}
#if USE_SSH
bool checkSshAuthentication(
const SshCredentials * ssh_credentials,
const std::vector<AuthenticationData> & authentication_methods)
const AuthenticationData & authentication_method)
{
const auto ssh_authentication_methods = getAuthenticationMethodsOfType(authentication_methods, {AuthenticationType::SSH_KEY});
for (const auto & auth_method : ssh_authentication_methods)
{
if (checkSshSignature(auth_method.getSSHKeys(), ssh_credentials->getSignature(), ssh_credentials->getOriginal()))
{
return true;
}
}
return false;
return AuthenticationType::SSH_KEY == authentication_method.getType()
&& checkSshSignature(authentication_method.getSSHKeys(), ssh_credentials->getSignature(), ssh_credentials->getOriginal());
}
#endif
}
bool Authentication::areCredentialsValid(
const Credentials & credentials,
const std::vector<AuthenticationData> & authentication_methods,
const AuthenticationData & authentication_method,
const ExternalAuthenticators & external_authenticators,
SettingsChanges & settings)
{
@ -261,28 +192,28 @@ bool Authentication::areCredentialsValid(
if (const auto * gss_acceptor_context = typeid_cast<const GSSAcceptorContext *>(&credentials))
{
return checkKerberosAuthentication(gss_acceptor_context, authentication_methods, external_authenticators);
return checkKerberosAuthentication(gss_acceptor_context, authentication_method, external_authenticators);
}
if (const auto * mysql_credentials = typeid_cast<const MySQLNative41Credentials *>(&credentials))
{
return checkMySQLAuthentication(mysql_credentials, authentication_methods);
return checkMySQLAuthentication(mysql_credentials, authentication_method);
}
if (const auto * basic_credentials = typeid_cast<const BasicCredentials *>(&credentials))
{
return checkBasicAuthentication(basic_credentials, authentication_methods, external_authenticators, settings);
return checkBasicAuthentication(basic_credentials, authentication_method, external_authenticators, settings);
}
if (const auto * ssl_certificate_credentials = typeid_cast<const SSLCertificateCredentials *>(&credentials))
{
return checkSSLCertificateAuthentication(ssl_certificate_credentials, authentication_methods);
return checkSSLCertificateAuthentication(ssl_certificate_credentials, authentication_method);
}
#if USE_SSH
if (const auto * ssh_credentials = typeid_cast<const SshCredentials *>(&credentials))
{
return checkSshAuthentication(ssh_credentials, authentication_methods);
return checkSshAuthentication(ssh_credentials, authentication_method);
}
#endif

View File

@ -24,7 +24,7 @@ struct Authentication
/// returned by the authentication server
static bool areCredentialsValid(
const Credentials & credentials,
const std::vector<AuthenticationData> & authentication_methods,
const AuthenticationData & authentication_method,
const ExternalAuthenticators & external_authenticators,
SettingsChanges & settings);

View File

@ -87,7 +87,6 @@ TEST_P(ParserTest, parseQuery)
{
if (input_text.starts_with("ATTACH"))
{
// todo arthur
auto salt = (dynamic_cast<const ASTCreateUserQuery *>(ast.get())->auth_data.back())->getSalt().value_or("");
EXPECT_TRUE(re2::RE2::FullMatch(salt, expected_ast));
}