Attempt to resolve nullptr in STS credentials provider for S3.

This commit is contained in:
Vladimir Chebotarev 2021-11-08 10:59:43 +03:00
parent 11b70a285c
commit 5e9710a26e

View File

@ -14,12 +14,14 @@
# include <aws/core/auth/AWSCredentialsProviderChain.h> # include <aws/core/auth/AWSCredentialsProviderChain.h>
# include <aws/core/auth/STSCredentialsProvider.h> # include <aws/core/auth/STSCredentialsProvider.h>
# include <aws/core/client/DefaultRetryStrategy.h> # include <aws/core/client/DefaultRetryStrategy.h>
# include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h>
# include <aws/core/platform/Environment.h> # include <aws/core/platform/Environment.h>
# include <aws/core/platform/OSVersionInfo.h> # include <aws/core/platform/OSVersionInfo.h>
# include <aws/core/utils/json/JsonSerializer.h> # include <aws/core/utils/json/JsonSerializer.h>
# include <aws/core/utils/logging/LogMacros.h> # include <aws/core/utils/logging/LogMacros.h>
# include <aws/core/utils/logging/LogSystemInterface.h> # include <aws/core/utils/logging/LogSystemInterface.h>
# include <aws/core/utils/HashingUtils.h> # include <aws/core/utils/HashingUtils.h>
# include <aws/core/utils/UUID.h>
# include <aws/core/http/HttpClientFactory.h> # include <aws/core/http/HttpClientFactory.h>
# include <aws/s3/S3Client.h> # include <aws/s3/S3Client.h>
@ -30,6 +32,8 @@
# include <boost/algorithm/string/case_conv.hpp> # include <boost/algorithm/string/case_conv.hpp>
# include <base/logger_useful.h> # include <base/logger_useful.h>
# include <fstream>
namespace namespace
{ {
@ -361,6 +365,156 @@ private:
Poco::Logger * logger; Poco::Logger * logger;
}; };
class AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider : public Aws::Auth::AWSCredentialsProvider
{
/// See STSAssumeRoleWebIdentityCredentialsProvider.
public:
AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider(DB::S3::PocoHTTPClientConfiguration & aws_client_configuration)
: logger(&Poco::Logger::get("AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider"))
{
// check environment variables
String tmp_region = Aws::Environment::GetEnv("AWS_DEFAULT_REGION");
role_arn = Aws::Environment::GetEnv("AWS_ROLE_ARN");
token_file = Aws::Environment::GetEnv("AWS_WEB_IDENTITY_TOKEN_FILE");
session_name = Aws::Environment::GetEnv("AWS_ROLE_SESSION_NAME");
// check profile_config if either m_roleArn or m_tokenFile is not loaded from environment variable
// region source is not enforced, but we need it to construct sts endpoint, if we can't find from environment, we should check if it's set in config file.
if (role_arn.empty() || token_file.empty() || tmp_region.empty())
{
auto profile = Aws::Config::GetCachedConfigProfile(Aws::Auth::GetConfigProfileName());
if (tmp_region.empty())
{
tmp_region = profile.GetRegion();
}
// If either of these two were not found from environment, use whatever found for all three in config file
if (role_arn.empty() || token_file.empty())
{
role_arn = profile.GetRoleArn();
token_file = profile.GetValue("web_identity_token_file");
session_name = profile.GetValue("role_session_name");
}
}
if (token_file.empty())
{
LOG_WARNING(logger, "Token file must be specified to use STS AssumeRole web identity creds provider.");
return; // No need to do further constructing
}
else
{
LOG_DEBUG(logger, "Resolved token_file from profile_config or environment variable to be {}", token_file);
}
if (role_arn.empty())
{
LOG_WARNING(logger, "RoleArn must be specified to use STS AssumeRole web identity creds provider.");
return; // No need to do further constructing
}
else
{
LOG_DEBUG(logger, "Resolved role_arn from profile_config or environment variable to be {}", role_arn);
}
if (tmp_region.empty())
{
tmp_region = Aws::Region::US_EAST_1;
}
else
{
LOG_DEBUG(logger, "Resolved region from profile_config or environment variable to be {}", tmp_region);
}
if (session_name.empty())
{
session_name = Aws::Utils::UUID::RandomUUID();
}
else
{
LOG_DEBUG(logger, "Resolved session_name from profile_config or environment variable to be {}", session_name);
}
aws_client_configuration.scheme = Aws::Http::Scheme::HTTPS;
aws_client_configuration.region = tmp_region;
std::vector<String> retryable_errors;
retryable_errors.push_back("IDPCommunicationError");
retryable_errors.push_back("InvalidIdentityToken");
aws_client_configuration.retryStrategy = std::make_shared<Aws::Client::SpecifiedRetryableErrorsRetryStrategy>(
retryable_errors, /* maxRetries = */3);
client = std::make_unique<Aws::Internal::STSCredentialsClient>(aws_client_configuration);
initialized = true;
LOG_INFO(logger, "Creating STS AssumeRole with web identity creds provider.");
}
Aws::Auth::AWSCredentials GetAWSCredentials() override
{
// A valid client means required information like role arn and token file were constructed correctly.
// We can use this provider to load creds, otherwise, we can just return empty creds.
if (!initialized)
{
return Aws::Auth::AWSCredentials();
}
RefreshIfExpired();
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
return credentials;
}
protected:
void Reload() override
{
LOG_INFO(logger, "Credentials have expired, attempting to renew from STS.");
std::ifstream token_stream(token_file.data());
if (token_stream)
{
String token_string((std::istreambuf_iterator<char>(token_stream)), std::istreambuf_iterator<char>());
token = token_string;
}
else
{
LOG_INFO(logger, "Can't open token file: {}", token_file);
return;
}
Aws::Internal::STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request{session_name, role_arn, token};
auto result = client->GetAssumeRoleWithWebIdentityCredentials(request);
LOG_TRACE(logger, "Successfully retrieved credentials with AWS_ACCESS_KEY: {}", result.creds.GetAWSAccessKeyId());
credentials = result.creds;
}
private:
void RefreshIfExpired()
{
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
if (!credentials.IsExpiredOrEmpty())
{
return;
}
guard.UpgradeToWriterLock();
if (!credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice
{
return;
}
Reload();
}
std::unique_ptr<Aws::Internal::STSCredentialsClient> client;
Aws::Auth::AWSCredentials credentials;
Aws::String role_arn;
Aws::String token_file;
Aws::String session_name;
Aws::String token;
bool initialized = false;
Poco::Logger * logger;
};
class S3CredentialsProviderChain : public Aws::Auth::AWSCredentialsProviderChain class S3CredentialsProviderChain : public Aws::Auth::AWSCredentialsProviderChain
{ {
public: public:
@ -381,7 +535,11 @@ public:
AddProvider(std::make_shared<Aws::Auth::EnvironmentAWSCredentialsProvider>()); AddProvider(std::make_shared<Aws::Auth::EnvironmentAWSCredentialsProvider>());
AddProvider(std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>()); AddProvider(std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>());
AddProvider(std::make_shared<Aws::Auth::ProcessCredentialsProvider>()); AddProvider(std::make_shared<Aws::Auth::ProcessCredentialsProvider>());
AddProvider(std::make_shared<Aws::Auth::STSAssumeRoleWebIdentityCredentialsProvider>());
{
DB::S3::PocoHTTPClientConfiguration aws_client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration(configuration.region, configuration.remote_host_filter, configuration.s3_max_redirects);
AddProvider(std::make_shared<AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider>(aws_client_configuration));
}
/// ECS TaskRole Credentials only available when ENVIRONMENT VARIABLE is set. /// ECS TaskRole Credentials only available when ENVIRONMENT VARIABLE is set.
const auto relative_uri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI); const auto relative_uri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI);