Merge pull request #49390 from ClickHouse/fix-gcs-headers

Properly fix GCS when HMAC is used
This commit is contained in:
Antonio Andelic 2023-05-08 08:17:43 +02:00 committed by GitHub
commit 212c57c034
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 95 additions and 36 deletions

View File

@ -112,6 +112,22 @@ std::unique_ptr<Client> Client::create(const Client & other)
return std::unique_ptr<Client>(new Client(other)); return std::unique_ptr<Client>(new Client(other));
} }
namespace
{

/// Guesses the storage provider from an endpoint URL by searching it for
/// well-known host fragments. The AWS fragment is checked before the GCS one;
/// a URL containing neither yields ProviderType::UNKNOWN.
ProviderType deduceProviderType(const std::string & url)
{
    const auto mentions = [&url](const char * host_fragment)
    {
        return url.find(host_fragment) != std::string::npos;
    };

    if (mentions(".amazonaws.com"))
        return ProviderType::AWS;

    return mentions("storage.googleapis.com") ? ProviderType::GCS : ProviderType::UNKNOWN;
}

}
Client::Client( Client::Client(
size_t max_redirects_, size_t max_redirects_,
ServerSideEncryptionKMSConfig sse_kms_config_, ServerSideEncryptionKMSConfig sse_kms_config_,
@ -128,9 +144,28 @@ Client::Client(
endpoint_provider->GetBuiltInParameters().GetParameter("Region").GetString(explicit_region); endpoint_provider->GetBuiltInParameters().GetParameter("Region").GetString(explicit_region);
endpoint_provider->GetBuiltInParameters().GetParameter("Endpoint").GetString(initial_endpoint); endpoint_provider->GetBuiltInParameters().GetParameter("Endpoint").GetString(initial_endpoint);
provider_type = getProviderTypeFromURL(initial_endpoint); provider_type = deduceProviderType(initial_endpoint);
LOG_TRACE(log, "Provider type: {}", toString(provider_type)); LOG_TRACE(log, "Provider type: {}", toString(provider_type));
if (provider_type == ProviderType::GCS)
{
/// GCS can operate in 2 modes for header and query param names:
/// - with both x-amz and x-goog prefixes allowed (but different prefixes cannot be mixed in the same request)
/// - only with the x-goog prefix
/// The first mode is allowed only with HMAC (or unsigned requests), so when we
/// find credential keys we can simply behave as if the underlying storage is S3.
/// Otherwise, we need to be aware that we are making requests to GCS
/// and replace all headers with a valid prefix when needed.
if (credentials_provider)
{
auto credentials = credentials_provider->GetAWSCredentials();
if (credentials.IsEmpty())
api_mode = ApiMode::GCS;
}
}
LOG_TRACE(log, "API mode: {}", toString(api_mode));
detect_region = provider_type == ProviderType::AWS && explicit_region == Aws::Region::AWS_GLOBAL; detect_region = provider_type == ProviderType::AWS && explicit_region == Aws::Region::AWS_GLOBAL;
cache = std::make_shared<ClientCache>(); cache = std::make_shared<ClientCache>();
@ -208,7 +243,7 @@ Model::HeadObjectOutcome Client::HeadObject(const HeadObjectRequest & request) c
{ {
const auto & bucket = request.GetBucket(); const auto & bucket = request.GetBucket();
request.setProviderType(provider_type); request.setApiMode(api_mode);
if (auto region = getRegionForBucket(bucket); !region.empty()) if (auto region = getRegionForBucket(bucket); !region.empty())
{ {
@ -348,7 +383,7 @@ std::invoke_result_t<RequestFn, RequestType>
Client::doRequest(const RequestType & request, RequestFn request_fn) const Client::doRequest(const RequestType & request, RequestFn request_fn) const
{ {
const auto & bucket = request.GetBucket(); const auto & bucket = request.GetBucket();
request.setProviderType(provider_type); request.setApiMode(api_mode);
if (auto region = getRegionForBucket(bucket); !region.empty()) if (auto region = getRegionForBucket(bucket); !region.empty())
{ {
@ -421,9 +456,23 @@ Client::doRequest(const RequestType & request, RequestFn request_fn) const
throw Exception(ErrorCodes::TOO_MANY_REDIRECTS, "Too many redirects"); throw Exception(ErrorCodes::TOO_MANY_REDIRECTS, "Too many redirects");
} }
ProviderType Client::getProviderType() const bool Client::supportsMultiPartCopy() const
{ {
return provider_type; return provider_type != ProviderType::GCS;
}
/// Builds the HTTP request by delegating to the base Aws::S3::S3Client implementation and,
/// when this client operates in GCS API mode, removes the S3-specific headers listed below
/// from the resulting request.
void Client::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request,
const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest) const
{
/// let the base client build the request first; headers are only removed afterwards
Aws::S3::S3Client::BuildHttpRequest(request, httpRequest);
if (api_mode == ApiMode::GCS)
{
/// some GCS requests don't like S3 specific headers that the client sets
httpRequest->DeleteHeader("x-amz-api-version");
httpRequest->DeleteHeader("amz-sdk-invocation-id");
httpRequest->DeleteHeader("amz-sdk-request");
}
} }
std::string Client::getRegionForBucket(const std::string & bucket, bool force_detect) const std::string Client::getRegionForBucket(const std::string & bucket, bool force_detect) const

View File

@ -190,7 +190,10 @@ public:
using Aws::S3::S3Client::EnableRequestProcessing; using Aws::S3::S3Client::EnableRequestProcessing;
using Aws::S3::S3Client::DisableRequestProcessing; using Aws::S3::S3Client::DisableRequestProcessing;
ProviderType getProviderType() const; void BuildHttpRequest(const Aws::AmazonWebServiceRequest& request,
const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest) const override;
bool supportsMultiPartCopy() const;
private: private:
Client(size_t max_redirects_, Client(size_t max_redirects_,
ServerSideEncryptionKMSConfig sse_kms_config_, ServerSideEncryptionKMSConfig sse_kms_config_,
@ -238,7 +241,12 @@ private:
std::string explicit_region; std::string explicit_region;
mutable bool detect_region = true; mutable bool detect_region = true;
/// The provider type determines whether some functionality is supported,
/// but for the same provider we may need to generate different headers depending on the
/// API mode.
/// E.g. GCS can work in AWS mode in some cases and accept headers with the x-amz prefix.
ProviderType provider_type{ProviderType::UNKNOWN}; ProviderType provider_type{ProviderType::UNKNOWN};
ApiMode api_mode{ApiMode::AWS};
mutable std::shared_ptr<ClientCache> cache; mutable std::shared_ptr<ClientCache> cache;

View File

@ -260,17 +260,6 @@ void PocoHTTPClient::makeRequestInternal(
Poco::Logger * log = &Poco::Logger::get("AWSClient"); Poco::Logger * log = &Poco::Logger::get("AWSClient");
auto uri = request.GetUri().GetURIString(); auto uri = request.GetUri().GetURIString();
#if 0
auto provider_type = getProviderTypeFromURL(uri);
if (provider_type == ProviderType::GCS)
{
/// some GCS requests don't like S3 specific headers that the client sets
request.DeleteHeader("x-amz-api-version");
request.DeleteHeader("amz-sdk-invocation-id");
request.DeleteHeader("amz-sdk-request");
}
#endif
if (enable_s3_requests_logging) if (enable_s3_requests_logging)
LOG_TEST(log, "Make request to: {}", uri); LOG_TEST(log, "Make request to: {}", uri);

View File

@ -22,20 +22,17 @@ std::string_view toString(ProviderType provider_type)
} }
} }
bool supportsMultiPartCopy(ProviderType provider_type) std::string_view toString(ApiMode api_mode)
{ {
return provider_type != ProviderType::GCS; using enum ApiMode;
}
ProviderType getProviderTypeFromURL(const std::string & url) switch (api_mode)
{ {
if (url.find(".amazonaws.com") != std::string::npos) case AWS:
return ProviderType::AWS; return "AWS";
case GCS:
if (url.find("storage.googleapis.com") != std::string::npos) return "GCS";
return ProviderType::GCS; }
return ProviderType::UNKNOWN;
} }
} }

View File

@ -10,6 +10,11 @@
namespace DB::S3 namespace DB::S3
{ {
/// Provider type defines the platform containing the object
/// we are trying to access
/// This information is useful for determining general support for
/// some feature like multipart copy which is currently supported by AWS
/// but not by GCS
enum class ProviderType : uint8_t enum class ProviderType : uint8_t
{ {
AWS, AWS,
@ -19,9 +24,20 @@ enum class ProviderType : uint8_t
std::string_view toString(ProviderType provider_type); std::string_view toString(ProviderType provider_type);
bool supportsMultiPartCopy(ProviderType provider_type); /// Mode in which we can use the XML API
/// This value can be the same as the provider type, but the two may differ.
/// For example, GCS can work in both
/// AWS compatible mode (accepts headers starting with x-amz)
/// and GCS mode (accepts only headers starting with x-goog).
/// Because GCS mode is enforced when some features are used, we
/// need to support both.
enum class ApiMode : uint8_t
{
    /// AWS compatible mode: headers starting with x-amz are accepted
    AWS,
    /// GCS mode: only headers starting with x-goog are accepted
    GCS,
};
ProviderType getProviderTypeFromURL(const std::string & url); std::string_view toString(ApiMode api_mode);
} }

View File

@ -10,7 +10,7 @@ namespace DB::S3
Aws::Http::HeaderValueCollection CopyObjectRequest::GetRequestSpecificHeaders() const Aws::Http::HeaderValueCollection CopyObjectRequest::GetRequestSpecificHeaders() const
{ {
auto headers = Model::CopyObjectRequest::GetRequestSpecificHeaders(); auto headers = Model::CopyObjectRequest::GetRequestSpecificHeaders();
if (provider_type != ProviderType::GCS) if (api_mode != ApiMode::GCS)
return headers; return headers;
/// GCS supports same headers as S3 but with a prefix x-goog instead of x-amz /// GCS supports same headers as S3 but with a prefix x-goog instead of x-amz

View File

@ -62,15 +62,15 @@ public:
return uri_override; return uri_override;
} }
void setProviderType(ProviderType provider_type_) const void setApiMode(ApiMode api_mode_) const
{ {
provider_type = provider_type_; api_mode = api_mode_;
} }
protected: protected:
mutable std::string region_override; mutable std::string region_override;
mutable std::optional<S3::URI> uri_override; mutable std::optional<S3::URI> uri_override;
mutable ProviderType provider_type{ProviderType::UNKNOWN}; mutable ApiMode api_mode{ApiMode::AWS};
}; };
class CopyObjectRequest : public ExtendedRequest<Model::CopyObjectRequest> class CopyObjectRequest : public ExtendedRequest<Model::CopyObjectRequest>

View File

@ -595,7 +595,7 @@ namespace
, src_key(src_key_) , src_key(src_key_)
, offset(src_offset_) , offset(src_offset_)
, size(src_size_) , size(src_size_)
, supports_multipart_copy(S3::supportsMultiPartCopy(client_ptr_->getProviderType())) , supports_multipart_copy(client_ptr_->supportsMultiPartCopy())
{ {
} }