diff --git a/contrib/aws b/contrib/aws index a220591e335..7d48b2c8193 160000 --- a/contrib/aws +++ b/contrib/aws @@ -1 +1 @@ -Subproject commit a220591e335923ce1c19bbf9eb925787f7ab6c13 +Subproject commit 7d48b2c8193679cc4516e5bd68ae4a64b94dae7d diff --git a/src/Disks/S3/registerDiskS3.cpp b/src/Disks/S3/registerDiskS3.cpp index 1b6086d0019..f9eddebdf88 100644 --- a/src/Disks/S3/registerDiskS3.cpp +++ b/src/Disks/S3/registerDiskS3.cpp @@ -112,32 +112,33 @@ void registerDiskS3(DiskFactory & factory) Poco::File disk{context.getPath() + "disks/" + name}; disk.createDirectories(); - Aws::Client::ClientConfiguration cfg; + S3::PocoHTTPClientConfiguration client_configuration = S3::ClientFactory::instance().createClientConfiguration( + context.getRemoteHostFilter(), + context.getGlobalContext().getSettingsRef().s3_max_redirects); S3::URI uri(Poco::URI(config.getString(config_prefix + ".endpoint"))); if (uri.key.back() != '/') throw Exception("S3 path must ends with '/', but '" + uri.key + "' doesn't.", ErrorCodes::BAD_ARGUMENTS); - cfg.connectTimeoutMs = config.getUInt(config_prefix + ".connect_timeout_ms", 10000); - cfg.httpRequestTimeoutMs = config.getUInt(config_prefix + ".request_timeout_ms", 5000); - cfg.maxConnections = config.getUInt(config_prefix + ".max_connections", 100); - cfg.endpointOverride = uri.endpoint; + client_configuration.connectTimeoutMs = config.getUInt(config_prefix + ".connect_timeout_ms", 10000); + client_configuration.httpRequestTimeoutMs = config.getUInt(config_prefix + ".request_timeout_ms", 5000); + client_configuration.maxConnections = config.getUInt(config_prefix + ".max_connections", 100); + client_configuration.endpointOverride = uri.endpoint; auto proxy_config = getProxyConfiguration(config_prefix, config); if (proxy_config) - cfg.perRequestConfiguration = [proxy_config](const auto & request) { return proxy_config->getConfiguration(request); }; + client_configuration.perRequestConfiguration = [proxy_config](const auto & request) { return proxy_config->getConfiguration(request); }; - cfg.retryStrategy = std::make_shared( + client_configuration.retryStrategy = std::make_shared( config.getUInt(config_prefix + ".retry_attempts", 10)); auto client = S3::ClientFactory::instance().create( - cfg, + client_configuration, uri.is_virtual_hosted_style, config.getString(config_prefix + ".access_key_id", ""), config.getString(config_prefix + ".secret_access_key", ""), - config.getBool(config_prefix + ".use_environment_credentials", config.getBool("s3.use_environment_credentials", false)), - context.getRemoteHostFilter(), - context.getGlobalContext().getSettingsRef().s3_max_redirects); + config.getBool(config_prefix + ".use_environment_credentials", config.getBool("s3.use_environment_credentials", false)) + ); String metadata_path = config.getString(config_prefix + ".metadata_path", context.getPath() + "disks/" + name + "/"); diff --git a/src/IO/S3/PocoHTTPClient.cpp b/src/IO/S3/PocoHTTPClient.cpp index 2389f9a2192..bf6d30986a9 100644 --- a/src/IO/S3/PocoHTTPClient.cpp +++ b/src/IO/S3/PocoHTTPClient.cpp @@ -6,13 +6,11 @@ #include #include -#include #include #include #include #include #include -#include #include #include #include "Poco/StreamCopier.h" @@ -49,11 +47,9 @@ namespace DB::S3 { PocoHTTPClientConfiguration::PocoHTTPClientConfiguration( - const Aws::Client::ClientConfiguration & cfg, const RemoteHostFilter & remote_host_filter_, unsigned int s3_max_redirects_) - : Aws::Client::ClientConfiguration(cfg) - , remote_host_filter(remote_host_filter_) + : remote_host_filter(remote_host_filter_) , s3_max_redirects(s3_max_redirects_) { } @@ -90,29 +86,19 @@ PocoHTTPClient::PocoHTTPClient(const PocoHTTPClientConfiguration & clientConfigu { } -std::shared_ptr PocoHTTPClient::MakeRequest( - Aws::Http::HttpRequest & request, - Aws::Utils::RateLimits::RateLimiterInterface * readLimiter, - Aws::Utils::RateLimits::RateLimiterInterface * writeLimiter) const -{ - auto response = Aws::MakeShared("PocoHTTPClient", request); - makeRequestInternal(request, response, readLimiter, writeLimiter); - return response; -} - std::shared_ptr PocoHTTPClient::MakeRequest( const std::shared_ptr & request, Aws::Utils::RateLimits::RateLimiterInterface * readLimiter, Aws::Utils::RateLimits::RateLimiterInterface * writeLimiter) const { - auto response = Aws::MakeShared("PocoHTTPClient", request); + auto response = Aws::MakeShared("PocoHTTPClient", request); makeRequestInternal(*request, response, readLimiter, writeLimiter); return response; } void PocoHTTPClient::makeRequestInternal( Aws::Http::HttpRequest & request, - std::shared_ptr & response, + std::shared_ptr & response, Aws::Utils::RateLimits::RateLimiterInterface *, Aws::Utils::RateLimits::RateLimiterInterface *) const { @@ -278,7 +264,7 @@ void PocoHTTPClient::makeRequestInternal( } } else - response->GetResponseStream().SetUnderlyingStream(std::make_shared>(session, response_body_stream)); + response->SetResponseBody(response_body_stream, session); return; } diff --git a/src/IO/S3/PocoHTTPClient.h b/src/IO/S3/PocoHTTPClient.h index e4fc453f388..918943a413c 100644 --- a/src/IO/S3/PocoHTTPClient.h +++ b/src/IO/S3/PocoHTTPClient.h @@ -2,9 +2,12 @@ #include #include +#include +#include #include #include #include +#include namespace Aws::Http::Standard { @@ -18,16 +21,52 @@ class Context; namespace DB::S3 { +class ClientFactory; struct PocoHTTPClientConfiguration : public Aws::Client::ClientConfiguration { const RemoteHostFilter & remote_host_filter; unsigned int s3_max_redirects; - PocoHTTPClientConfiguration(const Aws::Client::ClientConfiguration & cfg, const RemoteHostFilter & remote_host_filter_, - unsigned int s3_max_redirects_); - void updateSchemeAndRegion(); + +private: + PocoHTTPClientConfiguration(const RemoteHostFilter & remote_host_filter_, unsigned int s3_max_redirects_); + + /// Constructor of Aws::Client::ClientConfiguration must be called after AWS SDK initialization. + friend ClientFactory; +}; + +class PocoHTTPResponse : public Aws::Http::Standard::StandardHttpResponse +{ +public: + using SessionPtr = PooledHTTPSessionPtr; + + PocoHTTPResponse(const std::shared_ptr request) + : Aws::Http::Standard::StandardHttpResponse(request) + , body_stream(request->GetResponseStreamFactory()) + { + } + + void SetResponseBody(Aws::IStream & incoming_stream, SessionPtr & session_) + { + body_stream = Aws::Utils::Stream::ResponseStream( + Aws::New>("http result streambuf", session_, incoming_stream.rdbuf()) + ); + } + + Aws::IOStream & GetResponseBody() const override + { + return body_stream.GetUnderlyingStream(); + } + + Aws::Utils::Stream::ResponseStream && SwapResponseStreamOwnership() override + { + return std::move(body_stream); + } + +private: + Aws::Utils::Stream::ResponseStream body_stream; }; class PocoHTTPClient : public Aws::Http::HttpClient @@ -35,10 +74,6 @@ class PocoHTTPClient : public Aws::Http::HttpClient public: explicit PocoHTTPClient(const PocoHTTPClientConfiguration & clientConfiguration); ~PocoHTTPClient() override = default; - std::shared_ptr MakeRequest( - Aws::Http::HttpRequest & request, - Aws::Utils::RateLimits::RateLimiterInterface * readLimiter, - Aws::Utils::RateLimits::RateLimiterInterface * writeLimiter) const override; std::shared_ptr MakeRequest( const std::shared_ptr & request, @@ -48,7 +83,7 @@ public: private: void makeRequestInternal( Aws::Http::HttpRequest & request, - std::shared_ptr & response, + std::shared_ptr & response, Aws::Utils::RateLimits::RateLimiterInterface * readLimiter, Aws::Utils::RateLimits::RateLimiterInterface * writeLimiter) const; diff --git a/src/IO/S3/SessionAwareAwsStream.h b/src/IO/S3/SessionAwareAwsStream.h deleted file mode 100644 index f64be5dac16..00000000000 --- a/src/IO/S3/SessionAwareAwsStream.h +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include - -#include - - -namespace DB::S3 -{ -/** - * Wrapper of IOStream to store response stream and corresponding HTTP session. - */ -template -class SessionAwareAwsStream : public Aws::IStream -{ -public: - SessionAwareAwsStream(Session session_, std::istream & response_stream_) - : Aws::IStream(response_stream_.rdbuf()), session(std::move(session_)) - { - } - -private: - /// Poco HTTP session is holder of response stream. - Session session; -}; - -} diff --git a/src/IO/S3/SessionAwareIOStream.h b/src/IO/S3/SessionAwareIOStream.h new file mode 100644 index 00000000000..1640accb6fa --- /dev/null +++ b/src/IO/S3/SessionAwareIOStream.h @@ -0,0 +1,26 @@ +#pragma once + +#include + + +namespace DB::S3 +{ +/** + * Wrapper of IOStream to store response stream and corresponding HTTP session. + */ +template +class SessionAwareIOStream : public std::iostream +{ +public: + SessionAwareIOStream(Session session_, std::streambuf * sb) + : std::iostream(sb) + , session(std::move(session_)) + { + } + +private: + /// Poco HTTP session is holder of response stream. + Session session; +}; + +} diff --git a/src/IO/S3Common.cpp b/src/IO/S3Common.cpp index d4c4ba9bb02..fbcd4ed97f1 100644 --- a/src/IO/S3Common.cpp +++ b/src/IO/S3Common.cpp @@ -144,7 +144,7 @@ public: } else if (Aws::Utils::StringUtils::ToLower(ec2_metadata_disabled.c_str()) != "true") { - Aws::Client::ClientConfiguration aws_client_configuration; + DB::S3::PocoHTTPClientConfiguration aws_client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration(remote_host_filter, s3_max_redirects); /// See MakeDefaultHttpResourceClientConfiguration(). /// This is part of EC2 metadata client, but unfortunately it can't be accessed from outside @@ -165,8 +165,7 @@ public: aws_client_configuration.requestTimeoutMs = 1000; aws_client_configuration.retryStrategy = std::make_shared(1, 1000); - DB::S3::PocoHTTPClientConfiguration client_configuration(aws_client_configuration, remote_host_filter, s3_max_redirects); - auto ec2_metadata_client = std::make_shared(client_configuration); + auto ec2_metadata_client = std::make_shared(aws_client_configuration); auto config_loader = std::make_shared(ec2_metadata_client); AddProvider(std::make_shared(config_loader)); @@ -207,13 +206,32 @@ public: return result; } + bool SignRequest(Aws::Http::HttpRequest & request, const char * region, const char * service_name, bool sign_body) const override + { + auto result = Aws::Client::AWSAuthV4Signer::SignRequest(request, region, service_name, sign_body); + for (const auto & header : headers) + request.SetHeaderValue(header.name, header.value); + return result; + } + bool PresignRequest( Aws::Http::HttpRequest & request, const char * region, - const char * serviceName, long long expiration_time_sec) const override // NOLINT { - auto result = Aws::Client::AWSAuthV4Signer::PresignRequest(request, region, serviceName, expiration_time_sec); + auto result = Aws::Client::AWSAuthV4Signer::PresignRequest(request, region, expiration_time_sec); + for (const auto & header : headers) + request.SetHeaderValue(header.name, header.value); + return result; + } + + bool PresignRequest( + Aws::Http::HttpRequest & request, + const char * region, + const char * service_name, + long long expiration_time_sec) const override // NOLINT + { + auto result = Aws::Client::AWSAuthV4Signer::PresignRequest(request, region, service_name, expiration_time_sec); for (const auto & header : headers) request.SetHeaderValue(header.name, header.value); return result; @@ -265,33 +283,28 @@ namespace S3 const RemoteHostFilter & remote_host_filter, unsigned int s3_max_redirects) { - Aws::Client::ClientConfiguration cfg; + PocoHTTPClientConfiguration client_configuration(remote_host_filter, s3_max_redirects); if (!endpoint.empty()) - cfg.endpointOverride = endpoint; + client_configuration.endpointOverride = endpoint; - return create(cfg, + return create(client_configuration, is_virtual_hosted_style, access_key_id, secret_access_key, - use_environment_credentials, - remote_host_filter, - s3_max_redirects); + use_environment_credentials); } std::shared_ptr ClientFactory::create( // NOLINT - const Aws::Client::ClientConfiguration & cfg, + const PocoHTTPClientConfiguration & cfg_, bool is_virtual_hosted_style, const String & access_key_id, const String & secret_access_key, - bool use_environment_credentials, - const RemoteHostFilter & remote_host_filter, - unsigned int s3_max_redirects) + bool use_environment_credentials) { Aws::Auth::AWSCredentials credentials(access_key_id, secret_access_key); - PocoHTTPClientConfiguration client_configuration(cfg, remote_host_filter, s3_max_redirects); - + PocoHTTPClientConfiguration client_configuration = cfg_; client_configuration.updateSchemeAndRegion(); return std::make_shared( @@ -301,22 +314,19 @@ namespace S3 use_environment_credentials), // AWS credentials provider. std::move(client_configuration), // Client configuration. Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, // Sign policy. - is_virtual_hosted_style || cfg.endpointOverride.empty() // Use virtual addressing if endpoint is not specified. + is_virtual_hosted_style || client_configuration.endpointOverride.empty() // Use virtual addressing if endpoint is not specified. ); } std::shared_ptr ClientFactory::create( // NOLINT - const Aws::Client::ClientConfiguration & cfg, + const PocoHTTPClientConfiguration & cfg_, bool is_virtual_hosted_style, const String & access_key_id, const String & secret_access_key, HeaderCollection headers, - bool use_environment_credentials, - const RemoteHostFilter & remote_host_filter, - unsigned int s3_max_redirects) + bool use_environment_credentials) { - PocoHTTPClientConfiguration client_configuration(cfg, remote_host_filter, s3_max_redirects); - + PocoHTTPClientConfiguration client_configuration = cfg_; client_configuration.updateSchemeAndRegion(); Aws::Auth::AWSCredentials credentials(access_key_id, secret_access_key); @@ -329,6 +339,13 @@ namespace S3 ); } + PocoHTTPClientConfiguration ClientFactory::createClientConfiguration( // NOLINT + const RemoteHostFilter & remote_host_filter, + unsigned int s3_max_redirects) + { + return PocoHTTPClientConfiguration(remote_host_filter, s3_max_redirects); + } + URI::URI(const Poco::URI & uri_) { /// Case when bucket name represented in domain name of S3 URL. diff --git a/src/IO/S3Common.h b/src/IO/S3Common.h index e2ec0785811..c367444395d 100644 --- a/src/IO/S3Common.h +++ b/src/IO/S3Common.h @@ -7,6 +7,7 @@ #include #include #include +#include #include namespace Aws::S3 @@ -23,7 +24,6 @@ namespace DB namespace DB::S3 { - class ClientFactory { public: @@ -41,21 +41,21 @@ public: unsigned int s3_max_redirects); std::shared_ptr create( - const Aws::Client::ClientConfiguration & cfg, + const PocoHTTPClientConfiguration & cfg, bool is_virtual_hosted_style, const String & access_key_id, const String & secret_access_key, - bool use_environment_credentials, - const RemoteHostFilter & remote_host_filter, - unsigned int s3_max_redirects); + bool use_environment_credentials); std::shared_ptr create( - const Aws::Client::ClientConfiguration & cfg, + const PocoHTTPClientConfiguration & cfg, bool is_virtual_hosted_style, const String & access_key_id, const String & secret_access_key, HeaderCollection headers, - bool use_environment_credentials, + bool use_environment_credentials); + + PocoHTTPClientConfiguration createClientConfiguration( const RemoteHostFilter & remote_host_filter, unsigned int s3_max_redirects); diff --git a/src/Storages/StorageS3.cpp b/src/Storages/StorageS3.cpp index 7524cb18f9f..0af115dc0b5 100644 --- a/src/Storages/StorageS3.cpp +++ b/src/Storages/StorageS3.cpp @@ -222,7 +222,10 @@ StorageS3::StorageS3( if (access_key_id_.empty()) credentials = Aws::Auth::AWSCredentials(std::move(settings.access_key_id), std::move(settings.secret_access_key)); - Aws::Client::ClientConfiguration client_configuration; + S3::PocoHTTPClientConfiguration client_configuration = S3::ClientFactory::instance().createClientConfiguration( + context_.getRemoteHostFilter(), + context_.getGlobalContext().getSettingsRef().s3_max_redirects); + client_configuration.endpointOverride = uri_.endpoint; client_configuration.maxConnections = max_connections_; @@ -232,9 +235,8 @@ StorageS3::StorageS3( credentials.GetAWSAccessKeyId(), credentials.GetAWSSecretKey(), std::move(settings.headers), - settings.use_environment_credentials.value_or(global_context.getConfigRef().getBool("s3.use_environment_credentials", false)), - context_.getRemoteHostFilter(), - context_.getGlobalContext().getSettingsRef().s3_max_redirects); + settings.use_environment_credentials.value_or(global_context.getConfigRef().getBool("s3.use_environment_credentials", false)) + ); }