diff --git a/src/IO/S3/Credentials.cpp b/src/IO/S3/Credentials.cpp index 60a16395619..1bc37e2b6b3 100644 --- a/src/IO/S3/Credentials.cpp +++ b/src/IO/S3/Credentials.cpp @@ -1,4 +1,5 @@ #include +#include #include #include "Interpreters/Context.h" @@ -48,6 +49,8 @@ bool areCredentialsEmptyOrExpired(const Aws::Auth::AWSCredentials & credentials, return now >= credentials.GetExpiration() - std::chrono::seconds(expiration_window_seconds); } +const char SSO_CREDENTIALS_PROVIDER_LOG_TAG[] = "SSOCredentialsProvider"; + } AWSEC2MetadataClient::AWSEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration, const char * endpoint_) @@ -454,176 +457,137 @@ void AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider::refreshIfExpired() Reload(); } -class SSOCredentialsProvider : public Aws::Auth::AWSCredentialsProvider + +SSOCredentialsProvider::SSOCredentialsProvider(DB::S3::PocoHTTPClientConfiguration aws_client_configuration_, uint64_t expiration_window_seconds_) + : profile_to_use(Aws::Auth::GetConfigProfileName()) + , aws_client_configuration(std::move(aws_client_configuration_)) + , expiration_window_seconds(expiration_window_seconds_) + , logger(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG)) { -public: - SSOCredentialsProvider(); - explicit SSOCredentialsProvider(const Aws::String& profile); - /** - * Retrieves the credentials if found, otherwise returns empty credential set. - */ - Aws::Auth::AWSCredentials GetAWSCredentials() override; - -private: - Aws::UniquePtr m_client; - Aws::Auth::AWSCredentials m_credentials; - - // Profile description variables - Aws::String m_profileToUse; - - // The AWS account ID that temporary AWS credentials are resolved for. - Aws::String m_ssoAccountId; - // The AWS region where the SSO directory for the given sso_start_url is hosted. - // This is independent of the general region configuration and MUST NOT be conflated. - Aws::String m_ssoRegion; - // The expiration time of the accessToken. - Aws::Utils::DateTime m_expiresAt; - // The SSO Token Provider - Aws::Auth::SSOBearerTokenProvider m_bearerTokenProvider; - - void Reload() override; - void RefreshIfExpired(); - Aws::String LoadAccessTokenFile(const Aws::String& ssoAccessTokenPath); -}; - -static const char SSO_CREDENTIALS_PROVIDER_LOG_TAG[] = "SSOCredentialsProvider"; - -SSOCredentialsProvider::SSOCredentialsProvider() : m_profileToUse(Aws::Auth::GetConfigProfileName()) -{ - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Setting sso credentials provider to read config from {}", m_profileToUse); -} - -SSOCredentialsProvider::SSOCredentialsProvider(const Aws::String& profile) : m_profileToUse(profile), - m_bearerTokenProvider(profile) -{ - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Setting sso credentials provider to read config from {}", m_profileToUse); + LOG_INFO(logger, "Setting sso credentials provider to read config from {}", profile_to_use); } Aws::Auth::AWSCredentials SSOCredentialsProvider::GetAWSCredentials() { - RefreshIfExpired(); + refreshIfExpired(); Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock); - return m_credentials; + return credentials; } void SSOCredentialsProvider::Reload() { - auto profile = Aws::Config::GetCachedConfigProfile(m_profileToUse); - const auto accessToken = [&]() -> Aws::String { + auto profile = Aws::Config::GetCachedConfigProfile(profile_to_use); + const auto access_token = [&] + { // If we have an SSO Session set, use the refreshed token. - if (profile.IsSsoSessionSet()) { - m_ssoRegion = profile.GetSsoSession().GetSsoRegion(); - auto token = m_bearerTokenProvider.GetAWSBearerToken(); - m_expiresAt = token.GetExpiration(); + if (profile.IsSsoSessionSet()) + { + sso_region = profile.GetSsoSession().GetSsoRegion(); + auto token = bearer_token_provider.GetAWSBearerToken(); + expires_at = token.GetExpiration(); return token.GetToken(); } - Aws::String hashedStartUrl = Aws::Utils::HashingUtils::HexEncode(Aws::Utils::HashingUtils::CalculateSHA1(profile.GetSsoStartUrl())); - auto profileDirectory = Aws::Auth::ProfileConfigFileAWSCredentialsProvider::GetProfileDirectory(); - Aws::StringStream ssToken; - ssToken << profileDirectory; - ssToken << Aws::FileSystem::PATH_DELIM << "sso" << Aws::FileSystem::PATH_DELIM << "cache" << Aws::FileSystem::PATH_DELIM << hashedStartUrl << ".json"; - auto ssoTokenPath = ssToken.str(); - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Loading token from: {}", ssoTokenPath); - m_ssoRegion = profile.GetSsoRegion(); - return LoadAccessTokenFile(ssoTokenPath); + + Aws::String hashed_start_url = Aws::Utils::HashingUtils::HexEncode(Aws::Utils::HashingUtils::CalculateSHA1(profile.GetSsoStartUrl())); + auto profile_directory = Aws::Auth::ProfileConfigFileAWSCredentialsProvider::GetProfileDirectory(); + Aws::StringStream ss_token; + ss_token << profile_directory; + ss_token << Aws::FileSystem::PATH_DELIM << "sso" << Aws::FileSystem::PATH_DELIM << "cache" << Aws::FileSystem::PATH_DELIM << hashed_start_url << ".json"; + auto sso_token_path = ss_token.str(); + LOG_INFO(logger, "Loading token from: {}", sso_token_path); + sso_region = profile.GetSsoRegion(); + return loadAccessTokenFile(sso_token_path); }(); - if (accessToken.empty()) { - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Access token for SSO not available"); + + if (access_token.empty()) + { + LOG_TRACE(logger, "Access token for SSO not available"); return; } - if (m_expiresAt < Aws::Utils::DateTime::Now()) { - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Cached Token expired at {}", m_expiresAt.ToGmtString(Aws::Utils::DateFormat::ISO_8601)); + if (expires_at < Aws::Utils::DateTime::Now()) + { + LOG_TRACE(logger, "Cached Token expired at {}", expires_at.ToGmtString(Aws::Utils::DateFormat::ISO_8601)); return; } + Aws::Internal::SSOCredentialsClient::SSOGetRoleCredentialsRequest request; request.m_ssoAccountId = profile.GetSsoAccountId(); request.m_ssoRoleName = profile.GetSsoRoleName(); - request.m_accessToken = accessToken; + request.m_accessToken = access_token; auto context = DB::Context::getGlobalContextInstance(); - auto config = ClientFactory::instance().createClientConfiguration( - m_ssoRegion, - context->getRemoteHostFilter(), - static_cast(context->getGlobalContext()->getSettingsRef().s3_max_redirects), - context->getGlobalContext()->getSettingsRef().enable_s3_requests_logging, - /* for_disk_s3 = */ true, - {}, - {}, - "HTTPS" - ); - config.scheme = Aws::Http::Scheme::HTTPS; - config.region = m_ssoRegion; - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Passing config to client for region: {}", m_ssoRegion); + aws_client_configuration.scheme = Aws::Http::Scheme::HTTPS; + aws_client_configuration.region = sso_region; + LOG_TRACE(logger, "Passing config to client for region: {}", sso_region); - Aws::Vector retryableErrors; - retryableErrors.push_back("TooManyRequestsException"); + Aws::Vector retryable_errors; + retryable_errors.push_back("TooManyRequestsException"); - config.retryStrategy = Aws::MakeShared(SSO_CREDENTIALS_PROVIDER_LOG_TAG, retryableErrors, 3/*maxRetries*/); - m_client = Aws::MakeUnique(SSO_CREDENTIALS_PROVIDER_LOG_TAG, config); + aws_client_configuration.retryStrategy = Aws::MakeShared(SSO_CREDENTIALS_PROVIDER_LOG_TAG, retryable_errors, /*maxRetries=*/3); + client = Aws::MakeUnique(SSO_CREDENTIALS_PROVIDER_LOG_TAG, aws_client_configuration); - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Requesting credentials with AWS_ACCESS_KEY: {}", m_ssoAccountId); - auto result = m_client->GetSSOCredentials(request); - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Successfully retrieved credentials with AWS_ACCESS_KEY: {}", result.creds.GetAWSAccessKeyId()); + LOG_TRACE(logger, "Requesting credentials with AWS_ACCESS_KEY: {}", sso_account_id); + auto result = client->GetSSOCredentials(request); + LOG_TRACE(logger, "Successfully retrieved credentials with AWS_ACCESS_KEY: {}", result.creds.GetAWSAccessKeyId()); - m_credentials = result.creds; + credentials = result.creds; } -void SSOCredentialsProvider::RefreshIfExpired() +void SSOCredentialsProvider::refreshIfExpired() { Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock); - if (!m_credentials.IsExpiredOrEmpty()) - { + if (!areCredentialsEmptyOrExpired(credentials, expiration_window_seconds)) return; - } guard.UpgradeToWriterLock(); - if (!m_credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice - { + + if (!areCredentialsEmptyOrExpired(credentials, expiration_window_seconds)) // double-checked lock to avoid refreshing twice return; - } Reload(); } -Aws::String SSOCredentialsProvider::LoadAccessTokenFile(const Aws::String& ssoAccessTokenPath) +Aws::String SSOCredentialsProvider::loadAccessTokenFile(const Aws::String & sso_access_token_path) { - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Preparing to load token from: {}", ssoAccessTokenPath); + LOG_TRACE(logger, "Preparing to load token from: {}", sso_access_token_path); - Aws::IFStream inputFile(ssoAccessTokenPath.c_str()); - if(inputFile) + Aws::IFStream input_file(sso_access_token_path.c_str()); + + if (input_file) { - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Reading content from token file: {}", ssoAccessTokenPath); + LOG_TRACE(logger, "Reading content from token file: {}", sso_access_token_path); - Aws::Utils::Json::JsonValue tokenDoc(inputFile); - if (!tokenDoc.WasParseSuccessful()) + Aws::Utils::Json::JsonValue token_doc(input_file); + if (!token_doc.WasParseSuccessful()) { - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Failed to parse token file: {}", ssoAccessTokenPath); + LOG_TRACE(logger, "Failed to parse token file: {}", sso_access_token_path); return ""; } - Aws::Utils::Json::JsonView tokenView(tokenDoc); - Aws::String tmpAccessToken, expirationStr; - tmpAccessToken = tokenView.GetString("accessToken"); - expirationStr = tokenView.GetString("expiresAt"); - Aws::Utils::DateTime expiration(expirationStr, Aws::Utils::DateFormat::ISO_8601); + Aws::Utils::Json::JsonView token_view(token_doc); + Aws::String tmp_access_token, expiration_str; + tmp_access_token = token_view.GetString("accessToken"); + expiration_str = token_view.GetString("expiresAt"); + Aws::Utils::DateTime expiration(expiration_str, Aws::Utils::DateFormat::ISO_8601); - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Token cache file contains accessToken [{}], expiration [{}]", tmpAccessToken, expirationStr); + LOG_TRACE(logger, "Token cache file contains accessToken [{}], expiration [{}]", tmp_access_token, expiration_str); - if (tmpAccessToken.empty() || !expiration.WasParseSuccessful()) { - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), R"(The SSO session associated with this profile has expired or is otherwise invalid. To refresh this SSO session run aws sso login with the corresponding profile.)"); - LOG_INFO( - &Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), + if (tmp_access_token.empty() || !expiration.WasParseSuccessful()) + { + LOG_TRACE(logger, R"(The SSO session associated with this profile has expired or is otherwise invalid. To refresh this SSO session run aws sso login with the corresponding profile.)"); + LOG_TRACE( + logger, "Token cache file failed because {}{}", - (tmpAccessToken.empty() ? "AccessToken was empty " : ""), + (tmp_access_token.empty() ? "AccessToken was empty " : ""), (!expiration.WasParseSuccessful() ? "failed to parse expiration" : "")); return ""; } - m_expiresAt = expiration; - return tmpAccessToken; + expires_at = expiration; + return tmp_access_token; } else { - LOG_INFO(&Poco::Logger::get(SSO_CREDENTIALS_PROVIDER_LOG_TAG), "Unable to open token file on path: {}", ssoAccessTokenPath); + LOG_TRACE(logger, "Unable to open token file on path: {}", sso_access_token_path); return ""; } } @@ -673,6 +637,18 @@ S3CredentialsProviderChain::S3CredentialsProviderChain( AddProvider(std::make_shared()); + { + DB::S3::PocoHTTPClientConfiguration aws_client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration( + configuration.region, + configuration.remote_host_filter, + configuration.s3_max_redirects, + configuration.enable_s3_requests_logging, + configuration.for_disk_s3, + configuration.get_request_throttler, + configuration.put_request_throttler); + AddProvider(std::make_shared( + std::move(aws_client_configuration), credentials_configuration.expiration_window_seconds)); + } /// ECS TaskRole Credentials only available when ENVIRONMENT VARIABLE is set. const auto relative_uri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI); @@ -739,8 +715,6 @@ S3CredentialsProviderChain::S3CredentialsProviderChain( AddProvider(std::make_shared(config_loader)); LOG_INFO(logger, "Added EC2 metadata service credentials provider to the provider chain."); } - - AddProvider(std::make_shared()); } /// Quite verbose provider (argues if file with credentials doesn't exist) so iut's the last one diff --git a/src/IO/S3/Credentials.h b/src/IO/S3/Credentials.h index 324b750c683..0243e8e4986 100644 --- a/src/IO/S3/Credentials.h +++ b/src/IO/S3/Credentials.h @@ -8,6 +8,7 @@ # include # include # include +# include # include @@ -124,6 +125,39 @@ private: uint64_t expiration_window_seconds; }; +class SSOCredentialsProvider : public Aws::Auth::AWSCredentialsProvider +{ +public: + SSOCredentialsProvider(DB::S3::PocoHTTPClientConfiguration aws_client_configuration_, uint64_t expiration_window_seconds_); + + Aws::Auth::AWSCredentials GetAWSCredentials() override; + +private: + Aws::UniquePtr client; + Aws::Auth::AWSCredentials credentials; + + // Profile description variables + Aws::String profile_to_use; + + // The AWS account ID that temporary AWS credentials are resolved for. + Aws::String sso_account_id; + // The AWS region where the SSO directory for the given sso_start_url is hosted. + // This is independent of the general region configuration and MUST NOT be conflated. + Aws::String sso_region; + // The expiration time of the accessToken. + Aws::Utils::DateTime expires_at; + // The SSO Token Provider + Aws::Auth::SSOBearerTokenProvider bearer_token_provider; + + DB::S3::PocoHTTPClientConfiguration aws_client_configuration; + uint64_t expiration_window_seconds; + Poco::Logger * logger; + + void Reload() override; + void refreshIfExpired(); + Aws::String loadAccessTokenFile(const Aws::String & sso_access_token_path); +}; + struct CredentialsConfiguration { bool use_environment_credentials = false; diff --git a/src/IO/S3/PocoHTTPClientFactory.cpp b/src/IO/S3/PocoHTTPClientFactory.cpp index ade72a3dea6..9dd52a263b0 100644 --- a/src/IO/S3/PocoHTTPClientFactory.cpp +++ b/src/IO/S3/PocoHTTPClientFactory.cpp @@ -1,4 +1,3 @@ -#include "IO/S3/Client.h" #include "config.h" #if USE_AWS_S3