From 0621222737ef596e17d2e297ef43da49560bf84d Mon Sep 17 00:00:00 2001 From: Nikolay Degterinsky Date: Tue, 11 Apr 2023 14:19:45 +0000 Subject: [PATCH] Fix crashes with incorrect query parameters --- .../Access/InterpreterCreateUserQuery.cpp | 24 ++++++---- ...InterpreterShowCreateAccessEntityQuery.cpp | 5 +-- src/Parsers/Access/ASTAuthenticationData.cpp | 44 +++++++++++-------- src/Parsers/Access/ParserCreateUserQuery.cpp | 6 ++- 4 files changed, 46 insertions(+), 33 deletions(-) diff --git a/src/Interpreters/Access/InterpreterCreateUserQuery.cpp b/src/Interpreters/Access/InterpreterCreateUserQuery.cpp index 222ab2081ca..8924d7d887c 100644 --- a/src/Interpreters/Access/InterpreterCreateUserQuery.cpp +++ b/src/Interpreters/Access/InterpreterCreateUserQuery.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include "config.h" @@ -37,6 +38,11 @@ namespace if (query.type && *query.type == AuthenticationType::NO_PASSWORD) return AuthenticationData(); + size_t args_size = query.children.size(); + ASTs args(args_size); + for (size_t i = 0; i < args_size; ++i) + args[i] = evaluateConstantExpressionAsLiteral(query.children[i], context); + if (query.expect_password) { if (!query.type && !context) @@ -47,7 +53,7 @@ namespace /// NOTE: We will also extract bcrypt workfactor from context - String value = checkAndGetLiteralArgument(query.children[0], "password"); + String value = checkAndGetLiteralArgument(args[0], "password"); AuthenticationType current_type; @@ -97,33 +103,33 @@ namespace if (query.expect_hash) { - String value = checkAndGetLiteralArgument(query.children[0], "hash"); + String value = checkAndGetLiteralArgument(args[0], "hash"); auth_data.setPasswordHashHex(value); - if (*query.type == AuthenticationType::SHA256_PASSWORD && query.children.size() == 2) + if (*query.type == AuthenticationType::SHA256_PASSWORD && args_size == 2) { - String parsed_salt = checkAndGetLiteralArgument(query.children[1], "salt"); + String parsed_salt = checkAndGetLiteralArgument(args[1], "salt"); auth_data.setSalt(parsed_salt); } } else if (query.expect_ldap_server_name) { - String value = checkAndGetLiteralArgument(query.children[0], "ldap_server_name"); + String value = checkAndGetLiteralArgument(args[0], "ldap_server_name"); auth_data.setLDAPServerName(value); } else if (query.expect_kerberos_realm) { - if (!query.children.empty()) + if (!args.empty()) { - String value = checkAndGetLiteralArgument(query.children[0], "kerberos_realm"); + String value = checkAndGetLiteralArgument(args[0], "kerberos_realm"); auth_data.setKerberosRealm(value); } } else if (query.expect_common_names) { boost::container::flat_set common_names; - for (const auto & ast_child : query.children[0]->children) - common_names.insert(checkAndGetLiteralArgument(ast_child, "common_name")); + for (const auto & arg : args) + common_names.insert(checkAndGetLiteralArgument(arg, "common_name")); auth_data.setSSLCertificateCommonNames(std::move(common_names)); } diff --git a/src/Interpreters/Access/InterpreterShowCreateAccessEntityQuery.cpp b/src/Interpreters/Access/InterpreterShowCreateAccessEntityQuery.cpp index de7df523a3e..0878c9d3207 100644 --- a/src/Interpreters/Access/InterpreterShowCreateAccessEntityQuery.cpp +++ b/src/Interpreters/Access/InterpreterShowCreateAccessEntityQuery.cpp @@ -92,12 +92,9 @@ namespace case AuthenticationType::SSL_CERTIFICATE: { node->expect_common_names = true; - - auto list = std::make_shared(); for (const auto & name : auth_data.getSSLCertificateCommonNames()) - list->children.push_back(std::make_shared(name)); + node->children.push_back(std::make_shared(name)); - node->children.push_back(std::move(list)); break; } diff --git a/src/Parsers/Access/ASTAuthenticationData.cpp b/src/Parsers/Access/ASTAuthenticationData.cpp index aa4642c10cb..b214be07f42 100644 --- a/src/Parsers/Access/ASTAuthenticationData.cpp +++ b/src/Parsers/Access/ASTAuthenticationData.cpp @@ -31,7 +31,7 @@ std::optional ASTAuthenticationData::getSalt() const { if (type && *type == AuthenticationType::SHA256_PASSWORD && children.size() == 2) { - if (const auto * salt = children[0]->as()) + if (const auto * salt = children[1]->as()) { return salt->value.safeGet(); } @@ -51,10 +51,10 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt String auth_type_name; String prefix; /// "BY" or "SERVER" or "REALM" - ASTPtr password; /// either a password or hash - ASTPtr salt; - ASTPtr parameter; - ASTPtr parameters; + bool password = false; /// either a password or hash + bool salt = false; + bool parameter = false; + bool parameters = false; if (type) { @@ -65,7 +65,7 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt case AuthenticationType::PLAINTEXT_PASSWORD: { prefix = "BY"; - password = children[0]; + password = true; break; } case AuthenticationType::SHA256_PASSWORD: @@ -74,9 +74,9 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt auth_type_name = "sha256_hash"; prefix = "BY"; - password = children[0]; + password = true; if (children.size() == 2) - salt = children[1]; + salt = true; break; } case AuthenticationType::DOUBLE_SHA1_PASSWORD: @@ -85,13 +85,13 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt auth_type_name = "double_sha1_hash"; prefix = "BY"; - password = children[0]; + password = true; break; } case AuthenticationType::LDAP: { prefix = "SERVER"; - parameter = children[0]; + parameter = true; break; } case AuthenticationType::KERBEROS: @@ -99,14 +99,14 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt if (!children.empty()) { prefix = "REALM"; - parameter = children[0]; + parameter = true; } break; } case AuthenticationType::SSL_CERTIFICATE: { prefix = "CN"; - parameters = children[0]; + parameters = true; break; } case AuthenticationType::NO_PASSWORD: [[fallthrough]]; @@ -118,14 +118,14 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt { /// Default password type prefix = "BY"; - password = children[0]; + password = true; } if (password && !settings.show_secrets) { prefix = ""; - password.reset(); - salt.reset(); + password = false; + salt = false; if (type) auth_type_name = AuthenticationTypeInfo::get(*type).name; } @@ -146,24 +146,30 @@ void ASTAuthenticationData::formatImpl(const FormatSettings & settings, FormatSt if (password) { settings.ostr << " "; - password->format(settings); + children[0]->format(settings); } if (salt) { settings.ostr << " SALT "; - salt->format(settings); + children[1]->format(settings); } if (parameter) { settings.ostr << " "; - parameter->format(settings); + children[0]->format(settings); } else if (parameters) { settings.ostr << " "; - parameters->format(settings); + bool need_comma = false; + for (const auto & child : children) + { + if (std::exchange(need_comma, true)) + settings.ostr << ", "; + child->format(settings); + } } } diff --git a/src/Parsers/Access/ParserCreateUserQuery.cpp b/src/Parsers/Access/ParserCreateUserQuery.cpp index 367e137b6c8..4f26c624e9b 100644 --- a/src/Parsers/Access/ParserCreateUserQuery.cpp +++ b/src/Parsers/Access/ParserCreateUserQuery.cpp @@ -122,6 +122,7 @@ namespace ASTPtr value; ASTPtr parsed_salt; + ASTPtr common_names; if (expect_password || expect_hash) { if (!ParserKeyword{"BY"}.ignore(pos, expected) || !ParserStringAndSubstitution{}.parse(pos, value, expected)) @@ -154,7 +155,7 @@ namespace if (!ParserKeyword{"CN"}.ignore(pos, expected)) return false; - if (!ParserList{std::make_unique(), std::make_unique(TokenType::Comma), false}.parse(pos, value, expected)) + if (!ParserList{std::make_unique(), std::make_unique(TokenType::Comma), false}.parse(pos, common_names, expected)) return false; } @@ -173,6 +174,9 @@ namespace if (parsed_salt) auth_data->children.push_back(std::move(parsed_salt)); + if (common_names) + auth_data->children = std::move(common_names->children); + return true; }); }