Properly fix GCS with keys

This commit is contained in:
Antonio Andelic 2023-05-02 10:47:38 +00:00
parent 1267fbca1c
commit 8c91dbdfc6
5 changed files with 49 additions and 26 deletions

View File

@ -111,6 +111,36 @@ std::unique_ptr<Client> Client::create(const Client & other)
return std::unique_ptr<Client>(new Client(other));
}
namespace
{

/// Works out which storage provider this client should treat its endpoint as.
/// Non-empty credential keys take precedence over the URL: in that case AWS is reported
/// regardless of the host. Otherwise the answer comes from substrings found in `url`.
ProviderType deduceProviderType(const std::string & url, const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider)
{
    /// GCS has two modes for the names of headers and query params:
    ///  - x-amz and x-goog prefixes are both accepted (but must not be mixed within one request)
    ///  - only the x-goog prefix is accepted
    /// The first mode is available only with HMAC (or unsigned requests), so as soon as
    /// credential keys are found we can act as if the underlying storage were S3.
    /// In every other case we have to know that we are talking to GCS
    /// so that headers can be given a valid prefix when needed.
    const bool has_credential_keys = credentials_provider && !credentials_provider->GetAWSCredentials().IsEmpty();
    if (has_credential_keys)
        return ProviderType::AWS;

    const auto url_contains = [&url](const char * fragment) { return url.find(fragment) != std::string::npos; };

    if (url_contains(".amazonaws.com"))
        return ProviderType::AWS;

    return url_contains("storage.googleapis.com") ? ProviderType::GCS : ProviderType::UNKNOWN;
}

}
Client::Client(
size_t max_redirects_,
const std::shared_ptr<Aws::Auth::AWSCredentialsProvider> & credentials_provider,
@ -125,10 +155,10 @@ Client::Client(
endpoint_provider->GetBuiltInParameters().GetParameter("Region").GetString(explicit_region);
endpoint_provider->GetBuiltInParameters().GetParameter("Endpoint").GetString(initial_endpoint);
provider_type = getProviderTypeFromURL(initial_endpoint);
provider_type = deduceProviderType(initial_endpoint, credentials_provider);
LOG_TRACE(log, "Provider type: {}", toString(provider_type));
detect_region = provider_type == ProviderType::AWS && explicit_region == Aws::Region::AWS_GLOBAL;
detect_region = initial_endpoint.find(".amazonaws.com") != std::string::npos && explicit_region == Aws::Region::AWS_GLOBAL;
cache = std::make_shared<ClientCache>();
ClientCacheRegistry::instance().registerClient(cache);
@ -400,6 +430,20 @@ ProviderType Client::getProviderType() const
return provider_type;
}
/// Lets the base S3 client build the HTTP request and then, for GCS only,
/// removes the S3-specific headers listed below from the finished request.
void Client::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request,
                              const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest) const
{
    Aws::S3::S3Client::BuildHttpRequest(request, httpRequest);

    /// Nothing to strip unless the provider is GCS.
    if (provider_type != ProviderType::GCS)
        return;

    /// some GCS requests don't like S3 specific headers that the client sets
    static constexpr const char * s3_specific_headers[] = {
        "x-amz-api-version",
        "amz-sdk-invocation-id",
        "amz-sdk-request",
    };

    for (const char * header_name : s3_specific_headers)
        httpRequest->DeleteHeader(header_name);
}
std::string Client::getRegionForBucket(const std::string & bucket, bool force_detect) const
{
std::lock_guard lock(cache->region_cache_mutex);

View File

@ -163,6 +163,9 @@ public:
using Aws::S3::S3Client::DisableRequestProcessing;
/// Returns the provider type held by this client.
ProviderType getProviderType() const;
/// Override of the base S3 client's request building. The out-of-line definition calls the
/// base implementation and, when the provider type is GCS, deletes the S3-specific headers
/// ("x-amz-api-version", "amz-sdk-invocation-id", "amz-sdk-request") from the request.
void BuildHttpRequest(const Aws::AmazonWebServiceRequest& request,
const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest) const override;
private:
Client(size_t max_redirects_,
const std::shared_ptr<Aws::Auth::AWSCredentialsProvider>& credentials_provider,

View File

@ -260,17 +260,6 @@ void PocoHTTPClient::makeRequestInternal(
Poco::Logger * log = &Poco::Logger::get("AWSClient");
auto uri = request.GetUri().GetURIString();
#if 0
auto provider_type = getProviderTypeFromURL(uri);
if (provider_type == ProviderType::GCS)
{
/// some GCS requests don't like S3 specific headers that the client sets
request.DeleteHeader("x-amz-api-version");
request.DeleteHeader("amz-sdk-invocation-id");
request.DeleteHeader("amz-sdk-request");
}
#endif
if (enable_s3_requests_logging)
LOG_TEST(log, "Make request to: {}", uri);

View File

@ -27,17 +27,6 @@ bool supportsMultiPartCopy(ProviderType provider_type)
return provider_type != ProviderType::GCS;
}
/// Guesses the provider purely from well-known host fragments in `url`.
/// The AWS fragment is checked first; a URL matching neither fragment yields UNKNOWN.
ProviderType getProviderTypeFromURL(const std::string & url)
{
    const bool looks_like_aws = url.find(".amazonaws.com") != std::string::npos;
    const bool looks_like_gcs = url.find("storage.googleapis.com") != std::string::npos;

    if (looks_like_aws)
        return ProviderType::AWS;

    return looks_like_gcs ? ProviderType::GCS : ProviderType::UNKNOWN;
}
}
#endif

View File

@ -21,8 +21,6 @@ std::string_view toString(ProviderType provider_type);
bool supportsMultiPartCopy(ProviderType provider_type);
ProviderType getProviderTypeFromURL(const std::string & url);
}
#endif