Fix crashes with incorrect query parameters

This commit is contained in:
Nikolay Degterinsky 2023-04-11 14:19:45 +00:00
parent f34b304707
commit 0621222737
4 changed files with 46 additions and 33 deletions

View File

@ -11,6 +11,7 @@
#include <Interpreters/Context.h>
#include <Interpreters/executeDDLQueryOnCluster.h>
#include <Storages/checkAndGetLiteralArgument.h>
#include <Interpreters/evaluateConstantExpression.h>
#include <boost/range/algorithm/copy.hpp>
#include "config.h"
@ -37,6 +38,11 @@ namespace
if (query.type && *query.type == AuthenticationType::NO_PASSWORD)
return AuthenticationData();
size_t args_size = query.children.size();
ASTs args(args_size);
for (size_t i = 0; i < args_size; ++i)
args[i] = evaluateConstantExpressionAsLiteral(query.children[i], context);
if (query.expect_password)
{
if (!query.type && !context)
@ -47,7 +53,7 @@ namespace
/// NOTE: We will also extract bcrypt workfactor from context
String value = checkAndGetLiteralArgument<String>(query.children[0], "password");
String value = checkAndGetLiteralArgument<String>(args[0], "password");
AuthenticationType current_type;
@ -97,33 +103,33 @@ namespace
if (query.expect_hash)
{
String value = checkAndGetLiteralArgument<String>(query.children[0], "hash");
String value = checkAndGetLiteralArgument<String>(args[0], "hash");
auth_data.setPasswordHashHex(value);
if (*query.type == AuthenticationType::SHA256_PASSWORD && query.children.size() == 2)
if (*query.type == AuthenticationType::SHA256_PASSWORD && args_size == 2)
{
String parsed_salt = checkAndGetLiteralArgument<String>(query.children[1], "salt");
String parsed_salt = checkAndGetLiteralArgument<String>(args[1], "salt");
auth_data.setSalt(parsed_salt);
}
}
else if (query.expect_ldap_server_name)
{
String value = checkAndGetLiteralArgument<String>(query.children[0], "ldap_server_name");
String value = checkAndGetLiteralArgument<String>(args[0], "ldap_server_name");
auth_data.setLDAPServerName(value);
}
else if (query.expect_kerberos_realm)
{
if (!query.children.empty())
if (!args.empty())
{
String value = checkAndGetLiteralArgument<String>(query.children[0], "kerberos_realm");
String value = checkAndGetLiteralArgument<String>(args[0], "kerberos_realm");
auth_data.setKerberosRealm(value);
}
}
else if (query.expect_common_names)
{
boost::container::flat_set<String> common_names;
for (const auto & ast_child : query.children[0]->children)
common_names.insert(checkAndGetLiteralArgument<String>(ast_child, "common_name"));
for (const auto & arg : args)
common_names.insert(checkAndGetLiteralArgument<String>(arg, "common_name"));
auth_data.setSSLCertificateCommonNames(std::move(common_names));
}

View File

@ -92,12 +92,9 @@ namespace
case AuthenticationType::SSL_CERTIFICATE:
{
node->expect_common_names = true;
auto list = std::make_shared<ASTExpressionList>();
for (const auto & name : auth_data.getSSLCertificateCommonNames())
list->children.push_back(std::make_shared<ASTLiteral>(name));
node->children.push_back(std::make_shared<ASTLiteral>(name));
node->children.push_back(std::move(list));
break;
}

View File

@ -31,7 +31,7 @@ std::optional<String> ASTAuthenticationData::getSalt() const
{
if (type && *type == AuthenticationType::SHA256_PASSWORD && children.size() == 2)
{
if (const auto * salt = children[0]->as<const ASTLiteral>())
if (const auto * salt = children[1]->as<const ASTLiteral>())
{
return salt->value.safeGet<String>();
}
@ -51,10 +51,10 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt
String auth_type_name;
String prefix; /// "BY" or "SERVER" or "REALM"
ASTPtr password; /// either a password or hash
ASTPtr salt;
ASTPtr parameter;
ASTPtr parameters;
bool password = false; /// either a password or hash
bool salt = false;
bool parameter = false;
bool parameters = false;
if (type)
{
@ -65,7 +65,7 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt
case AuthenticationType::PLAINTEXT_PASSWORD:
{
prefix = "BY";
password = children[0];
password = true;
break;
}
case AuthenticationType::SHA256_PASSWORD:
@ -74,9 +74,9 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt
auth_type_name = "sha256_hash";
prefix = "BY";
password = children[0];
password = true;
if (children.size() == 2)
salt = children[1];
salt = true;
break;
}
case AuthenticationType::DOUBLE_SHA1_PASSWORD:
@ -85,13 +85,13 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt
auth_type_name = "double_sha1_hash";
prefix = "BY";
password = children[0];
password = true;
break;
}
case AuthenticationType::LDAP:
{
prefix = "SERVER";
parameter = children[0];
parameter = true;
break;
}
case AuthenticationType::KERBEROS:
@ -99,14 +99,14 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt
if (!children.empty())
{
prefix = "REALM";
parameter = children[0];
parameter = true;
}
break;
}
case AuthenticationType::SSL_CERTIFICATE:
{
prefix = "CN";
parameters = children[0];
parameters = true;
break;
}
case AuthenticationType::NO_PASSWORD: [[fallthrough]];
@ -118,14 +118,14 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt
{
/// Default password type
prefix = "BY";
password = children[0];
password = true;
}
if (password && !settings.show_secrets)
{
prefix = "";
password.reset();
salt.reset();
password = false;
salt = false;
if (type)
auth_type_name = AuthenticationTypeInfo::get(*type).name;
}
@ -146,24 +146,30 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt
if (password)
{
settings.ostr << " ";
password->format(settings);
children[0]->format(settings);
}
if (salt)
{
settings.ostr << " SALT ";
salt->format(settings);
children[1]->format(settings);
}
if (parameter)
{
settings.ostr << " ";
parameter->format(settings);
children[0]->format(settings);
}
else if (parameters)
{
settings.ostr << " ";
parameters->format(settings);
bool need_comma = false;
for (const auto & child : children)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
child->format(settings);
}
}
}

View File

@ -122,6 +122,7 @@ namespace
ASTPtr value;
ASTPtr parsed_salt;
ASTPtr common_names;
if (expect_password || expect_hash)
{
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !ParserStringAndSubstitution{}.parse(pos, value, expected))
@ -154,7 +155,7 @@ namespace
if (!ParserKeyword{"CN"}.ignore(pos, expected))
return false;
if (!ParserList{std::make_unique<ParserStringAndSubstitution>(), std::make_unique<ParserToken>(TokenType::Comma), false}.parse(pos, value, expected))
if (!ParserList{std::make_unique<ParserStringAndSubstitution>(), std::make_unique<ParserToken>(TokenType::Comma), false}.parse(pos, common_names, expected))
return false;
}
@ -173,6 +174,9 @@ namespace
if (parsed_salt)
auth_data->children.push_back(std::move(parsed_salt));
if (common_names)
auth_data->children = std::move(common_names->children);
return true;
});
}