Merge branch 'ClickHouse:master' into http_client_version

commit f3bcf26959
Geoff Genz, 2023-02-14 09:31:29 -07:00, committed by GitHub
45 changed files with 1246 additions and 888 deletions


@@ -26,6 +26,22 @@ logging.basicConfig(
total_start_seconds = time.perf_counter()
stage_start_seconds = total_start_seconds
# Thread executor that does not hide exceptions that happen during function
# execution, and rethrows them after join()
class SafeThread(Thread):
    run_exception = None

    def run(self):
        try:
            super().run()
        except:
            self.run_exception = sys.exc_info()

    def join(self):
        super().join()
        if self.run_exception:
            raise self.run_exception[1]

def reportStageEnd(stage):
    global stage_start_seconds, total_start_seconds
@@ -283,7 +299,7 @@ if not args.use_existing_tables:
print(f"create\t{index}\t{connection.last_query.elapsed}\t{tsv_escape(q)}")
threads = [
-    Thread(target=do_create, args=(connection, index, create_queries))
+    SafeThread(target=do_create, args=(connection, index, create_queries))
    for index, connection in enumerate(all_connections)
]
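The comment above describes the intent: a plain Thread swallows an exception raised in its target, while SafeThread re-raises it when the main thread calls join(). A minimal standalone sketch of that behavior (failing_query is a hypothetical stand-in for the do_create worker):

import sys
from threading import Thread

class SafeThread(Thread):
    """Same idea as the hunk above: capture the exception from run(), rethrow in join()."""
    run_exception = None

    def run(self):
        try:
            super().run()
        except Exception:
            self.run_exception = sys.exc_info()

    def join(self):
        super().join()
        if self.run_exception:
            raise self.run_exception[1]

def failing_query():
    raise RuntimeError("simulated CREATE failure")

t = SafeThread(target=failing_query)
t.start()
try:
    t.join()  # a plain Thread.join() would return silently here
except RuntimeError as e:
    print("caught:", e)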


@@ -155,7 +155,7 @@ if [[ -n "$USE_DATABASE_REPLICATED" ]] && [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]
sudo clickhouse stop --pid-path /var/run/clickhouse-server2 ||:
fi
-rg -Fa "Fatal" /var/log/clickhouse-server/clickhouse-server.log ||:
+rg -Fa "<Fatal>" /var/log/clickhouse-server/clickhouse-server.log ||:
zstd --threads=0 < /var/log/clickhouse-server/clickhouse-server.log > /test_output/clickhouse-server.log.zst ||:
# FIXME: remove once only GitHub Actions remains
@@ -166,8 +166,8 @@ if [[ -n "$WITH_COVERAGE" ]] && [[ "$WITH_COVERAGE" -eq 1 ]]; then
tar --zstd -c -h -f /test_output/clickhouse_coverage.tar.zst /profraw ||:
fi
if [[ -n "$USE_DATABASE_REPLICATED" ]] && [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]; then
rg -Fa "Fatal" /var/log/clickhouse-server/clickhouse-server1.log ||:
rg -Fa "Fatal" /var/log/clickhouse-server/clickhouse-server2.log ||:
rg -Fa "<Fatal>" /var/log/clickhouse-server/clickhouse-server1.log ||:
rg -Fa "<Fatal>" /var/log/clickhouse-server/clickhouse-server2.log ||:
zstd --threads=0 < /var/log/clickhouse-server/clickhouse-server1.log > /test_output/clickhouse-server1.log.zst ||:
zstd --threads=0 < /var/log/clickhouse-server/clickhouse-server2.log > /test_output/clickhouse-server2.log.zst ||:
# FIXME: remove once only GitHub Actions remains


@@ -169,7 +169,7 @@ if [[ -n "$USE_DATABASE_REPLICATED" ]] && [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]
sudo clickhouse stop --pid-path /var/run/clickhouse-server2 ||:
fi
-rg -Fa "Fatal" /var/log/clickhouse-server/clickhouse-server.log ||:
+rg -Fa "<Fatal>" /var/log/clickhouse-server/clickhouse-server.log ||:
zstd --threads=0 < /var/log/clickhouse-server/clickhouse-server.log > /test_output/clickhouse-server.log.zst &
# Compress tables.
@@ -215,8 +215,8 @@ fi
tar -chf /test_output/coordination.tar /var/lib/clickhouse/coordination ||:
if [[ -n "$USE_DATABASE_REPLICATED" ]] && [[ "$USE_DATABASE_REPLICATED" -eq 1 ]]; then
rg -Fa "Fatal" /var/log/clickhouse-server/clickhouse-server1.log ||:
rg -Fa "Fatal" /var/log/clickhouse-server/clickhouse-server2.log ||:
rg -Fa "<Fatal>" /var/log/clickhouse-server/clickhouse-server1.log ||:
rg -Fa "<Fatal>" /var/log/clickhouse-server/clickhouse-server2.log ||:
zstd --threads=0 < /var/log/clickhouse-server/clickhouse-server1.log > /test_output/clickhouse-server1.log.zst ||:
zstd --threads=0 < /var/log/clickhouse-server/clickhouse-server2.log > /test_output/clickhouse-server2.log.zst ||:
# FIXME: remove once only GitHub Actions remains


@@ -546,6 +546,9 @@ if [ "$DISABLE_BC_CHECK" -ne "1" ]; then
# it uses recently introduced settings which previous versions may not have
rm -f /etc/clickhouse-server/users.d/insert_keeper_retries.xml ||:
# Turn on after 23.1
rm -f /etc/clickhouse-server/users.d/prefetch_settings.xml ||:
start
clickhouse-client --query="SELECT 'Server version: ', version()"
@@ -718,7 +721,7 @@ mv /var/log/clickhouse-server/stderr.log /test_output/
# Write check result into check_status.tsv
# Try to choose most specific error for the whole check status
clickhouse-local --structure "test String, res String" -q "SELECT 'failure', test FROM table WHERE res != 'OK' order by
clickhouse-local --structure "test String, res String, time Nullable(Float32), desc String" -q "SELECT 'failure', test FROM table WHERE res != 'OK' order by
(test like 'Backward compatibility check%'), -- BC check goes last
(test like '%Sanitizer%') DESC,
(test like '%Killed by signal%') DESC,
@@ -732,7 +735,7 @@ clickhouse-local --structure "test String, res String" -q "SELECT 'failure', tes
(test like '%Error message%') DESC,
(test like '%previous release%') DESC,
rowNumberInAllBlocks()
LIMIT 1" < /test_output/test_results.tsv > /test_output/check_status.tsv
LIMIT 1" < /test_output/test_results.tsv > /test_output/check_status.tsv || echo "failure\tCannot parse test_results.tsv" > /test_output/check_status.tsv
[ -s /test_output/check_status.tsv ] || echo -e "success\tNo errors found" > /test_output/check_status.tsv
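The ORDER BY list above relies on ClickHouse booleans sorting as 0/1: each (test like ...) DESC key pulls matching rows to the front, a plain ascending key pushes matches to the back, and LIMIT 1 then keeps the most specific failure. A standalone sketch with hypothetical test names, runnable with clickhouse-local:

SELECT test
FROM
(
    SELECT 'Backward compatibility check: some test' AS test
    UNION ALL
    SELECT '01234_example_test (Killed by signal)'
)
ORDER BY
    (test like 'Backward compatibility check%'),  -- ascending: matches (1) sort after non-matches (0), so BC goes last
    (test like '%Killed by signal%') DESC         -- descending: matches sort first
LIMIT 1
-- returns '01234_example_test (Killed by signal)'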
# Core dumps


@@ -42,7 +42,7 @@ Internal coordination settings are located in the `<keeper_server>.<coordination
- `session_timeout_ms` — Max timeout for client session (ms) (default: 100000).
- `dead_session_check_period_ms` — How often ClickHouse Keeper checks for dead sessions and removes them (ms) (default: 500).
- `heart_beat_interval_ms` — How often a ClickHouse Keeper leader will send heartbeats to followers (ms) (default: 500).
-- `election_timeout_lower_bound_ms` — If the follower does not receive a heartbeat from the leader in this interval, then it can initiate leader election (default: 1000).
+- `election_timeout_lower_bound_ms` — If the follower does not receive a heartbeat from the leader in this interval, then it can initiate leader election (default: 1000). Must be less than or equal to `election_timeout_upper_bound_ms`. Ideally they shouldn't be equal.
- `election_timeout_upper_bound_ms` — If the follower does not receive a heartbeat from the leader in this interval, then it must initiate leader election (default: 2000).
- `rotate_log_storage_interval` — How many log records to store in a single file (default: 100000).
- `reserved_log_items` — How many coordination log records to store before compaction (default: 100000).
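As a minimal sketch, these keys sit in the `<keeper_server>.<coordination_settings>` section of the server config mentioned above; the values shown are the documented defaults:

<clickhouse>
    <keeper_server>
        <coordination_settings>
            <heart_beat_interval_ms>500</heart_beat_interval_ms>
            <!-- lower bound must be less than or equal to the upper bound; ideally strictly less -->
            <election_timeout_lower_bound_ms>1000</election_timeout_lower_bound_ms>
            <election_timeout_upper_bound_ms>2000</election_timeout_upper_bound_ms>
        </coordination_settings>
    </keeper_server>
</clickhouse>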


@@ -273,6 +273,19 @@ void KeeperServer::launchRaftServer(const Poco::Util::AbstractConfiguration & co
coordination_settings->election_timeout_lower_bound_ms.totalMilliseconds(), "election_timeout_lower_bound_ms", log);
params.election_timeout_upper_bound_ = getValueOrMaxInt32AndLogWarning(
coordination_settings->election_timeout_upper_bound_ms.totalMilliseconds(), "election_timeout_upper_bound_ms", log);
if (params.election_timeout_lower_bound_ || params.election_timeout_upper_bound_)
{
if (params.election_timeout_lower_bound_ >= params.election_timeout_upper_bound_)
{
LOG_FATAL(
log,
"election_timeout_lower_bound_ms is greater than election_timeout_upper_bound_ms, this would disable leader election "
"completely.");
std::terminate();
}
}
params.reserved_log_items_ = getValueOrMaxInt32AndLogWarning(coordination_settings->reserved_log_items, "reserved_log_items", log);
params.snapshot_distance_ = getValueOrMaxInt32AndLogWarning(coordination_settings->snapshot_distance, "snapshot_distance", log);

src/IO/S3/AWSLogger.cpp (new file, 75 lines)

@@ -0,0 +1,75 @@
#include <IO/S3/AWSLogger.h>
#if USE_AWS_S3
#include <aws/core/utils/logging/LogLevel.h>
namespace
{
const char * S3_LOGGER_TAG_NAMES[][2] = {
{"AWSClient", "AWSClient"},
{"AWSAuthV4Signer", "AWSClient (AWSAuthV4Signer)"},
};
const std::pair<DB::LogsLevel, Poco::Message::Priority> & convertLogLevel(Aws::Utils::Logging::LogLevel log_level)
{
/// We map levels to our own logger 1 to 1, except WARN+ levels. In most cases we recover from such errors with retries
/// and don't want to see them as Errors in our logs.
static const std::unordered_map<Aws::Utils::Logging::LogLevel, std::pair<DB::LogsLevel, Poco::Message::Priority>> mapping =
{
{Aws::Utils::Logging::LogLevel::Off, {DB::LogsLevel::none, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Fatal, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Error, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Warn, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Info, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Debug, {DB::LogsLevel::debug, Poco::Message::PRIO_TEST}},
{Aws::Utils::Logging::LogLevel::Trace, {DB::LogsLevel::trace, Poco::Message::PRIO_TEST}},
};
return mapping.at(log_level);
}
}
namespace DB::S3
{
AWSLogger::AWSLogger(bool enable_s3_requests_logging_)
: enable_s3_requests_logging(enable_s3_requests_logging_)
{
for (auto [tag, name] : S3_LOGGER_TAG_NAMES)
tag_loggers[tag] = &Poco::Logger::get(name);
default_logger = tag_loggers[S3_LOGGER_TAG_NAMES[0][0]];
}
Aws::Utils::Logging::LogLevel AWSLogger::GetLogLevel() const
{
if (enable_s3_requests_logging)
return Aws::Utils::Logging::LogLevel::Trace;
else
return Aws::Utils::Logging::LogLevel::Info;
}
void AWSLogger::Log(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * format_str, ...) // NOLINT
{
callLogImpl(log_level, tag, format_str); /// FIXME. Variadic arguments?
}
void AWSLogger::LogStream(Aws::Utils::Logging::LogLevel log_level, const char * tag, const Aws::OStringStream & message_stream)
{
callLogImpl(log_level, tag, message_stream.str().c_str());
}
void AWSLogger::callLogImpl(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * message)
{
const auto & [level, prio] = convertLogLevel(log_level);
if (tag_loggers.contains(tag))
LOG_IMPL(tag_loggers[tag], level, prio, fmt::runtime(message));
else
LOG_IMPL(default_logger, level, prio, "{}: {}", tag, message);
}
}
#endif

src/IO/S3/AWSLogger.h (new file, 37 lines)

@@ -0,0 +1,37 @@
#pragma once
#include "config.h"
#if USE_AWS_S3
#include <aws/core/utils/logging/LogSystemInterface.h>
#include <Common/logger_useful.h>
namespace DB::S3
{
class AWSLogger final : public Aws::Utils::Logging::LogSystemInterface
{
public:
explicit AWSLogger(bool enable_s3_requests_logging_);
~AWSLogger() final = default;
Aws::Utils::Logging::LogLevel GetLogLevel() const final;
void Log(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * format_str, ...) final; // NOLINT
void LogStream(Aws::Utils::Logging::LogLevel log_level, const char * tag, const Aws::OStringStream & message_stream) final;
void callLogImpl(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * message);
void Flush() final {}
private:
Poco::Logger * default_logger;
bool enable_s3_requests_logging;
std::unordered_map<String, Poco::Logger *> tag_loggers;
};
}
#endif


@@ -9,9 +9,13 @@
#include <aws/s3/model/ListObjectsV2Request.h>
#include <aws/core/client/AWSErrorMarshaller.h>
#include <aws/core/endpoint/EndpointParameter.h>
#include <aws/core/utils/HashingUtils.h>
#include <IO/S3Common.h>
#include <IO/S3/Requests.h>
#include <IO/S3/PocoHTTPClientFactory.h>
#include <IO/S3/AWSLogger.h>
#include <IO/S3/Credentials.h>
#include <Common/assert_cast.h>
@@ -393,6 +397,93 @@ void ClientCacheRegistry::clearCacheForAll()
}
ClientFactory::ClientFactory()
{
aws_options = Aws::SDKOptions{};
Aws::InitAPI(aws_options);
Aws::Utils::Logging::InitializeAWSLogging(std::make_shared<AWSLogger>(false));
Aws::Http::SetHttpClientFactory(std::make_shared<PocoHTTPClientFactory>());
}
ClientFactory::~ClientFactory()
{
Aws::Utils::Logging::ShutdownAWSLogging();
Aws::ShutdownAPI(aws_options);
}
ClientFactory & ClientFactory::instance()
{
static ClientFactory ret;
return ret;
}
std::unique_ptr<S3::Client> ClientFactory::create( // NOLINT
const PocoHTTPClientConfiguration & cfg_,
bool is_virtual_hosted_style,
const String & access_key_id,
const String & secret_access_key,
const String & server_side_encryption_customer_key_base64,
HTTPHeaderEntries headers,
bool use_environment_credentials,
bool use_insecure_imds_request)
{
PocoHTTPClientConfiguration client_configuration = cfg_;
client_configuration.updateSchemeAndRegion();
if (!server_side_encryption_customer_key_base64.empty())
{
/// See Client::GeneratePresignedUrlWithSSEC().
headers.push_back({Aws::S3::SSEHeaders::SERVER_SIDE_ENCRYPTION_CUSTOMER_ALGORITHM,
Aws::S3::Model::ServerSideEncryptionMapper::GetNameForServerSideEncryption(Aws::S3::Model::ServerSideEncryption::AES256)});
headers.push_back({Aws::S3::SSEHeaders::SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY,
server_side_encryption_customer_key_base64});
Aws::Utils::ByteBuffer buffer = Aws::Utils::HashingUtils::Base64Decode(server_side_encryption_customer_key_base64);
String str_buffer(reinterpret_cast<char *>(buffer.GetUnderlyingData()), buffer.GetLength());
headers.push_back({Aws::S3::SSEHeaders::SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY_MD5,
Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateMD5(str_buffer))});
}
client_configuration.extra_headers = std::move(headers);
Aws::Auth::AWSCredentials credentials(access_key_id, secret_access_key);
auto credentials_provider = std::make_shared<S3CredentialsProviderChain>(
client_configuration,
std::move(credentials),
use_environment_credentials,
use_insecure_imds_request);
client_configuration.retryStrategy = std::make_shared<Client::RetryStrategy>(std::move(client_configuration.retryStrategy));
return Client::create(
client_configuration.s3_max_redirects,
std::move(credentials_provider),
std::move(client_configuration), // Client configuration.
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
is_virtual_hosted_style || client_configuration.endpointOverride.empty() /// Use virtual addressing if endpoint is not specified.
);
}
PocoHTTPClientConfiguration ClientFactory::createClientConfiguration( // NOLINT
const String & force_region,
const RemoteHostFilter & remote_host_filter,
unsigned int s3_max_redirects,
bool enable_s3_requests_logging,
bool for_disk_s3,
const ThrottlerPtr & get_request_throttler,
const ThrottlerPtr & put_request_throttler)
{
return PocoHTTPClientConfiguration(
force_region,
remote_host_filter,
s3_max_redirects,
enable_s3_requests_logging,
for_disk_s3,
get_request_throttler,
put_request_throttler);
}
}
}


@@ -10,7 +10,9 @@
#include <IO/S3/URI.h>
#include <IO/S3/Requests.h>
#include <IO/S3/PocoHTTPClient.h>
#include <aws/core/Aws.h>
#include <aws/core/client/DefaultRetryStrategy.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/S3ServiceClientModel.h>
@@ -302,6 +304,39 @@ private:
Poco::Logger * log;
};
class ClientFactory
{
public:
~ClientFactory();
static ClientFactory & instance();
std::unique_ptr<S3::Client> create(
const PocoHTTPClientConfiguration & cfg,
bool is_virtual_hosted_style,
const String & access_key_id,
const String & secret_access_key,
const String & server_side_encryption_customer_key_base64,
HTTPHeaderEntries headers,
bool use_environment_credentials,
bool use_insecure_imds_request);
PocoHTTPClientConfiguration createClientConfiguration(
const String & force_region,
const RemoteHostFilter & remote_host_filter,
unsigned int s3_max_redirects,
bool enable_s3_requests_logging,
bool for_disk_s3,
const ThrottlerPtr & get_request_throttler,
const ThrottlerPtr & put_request_throttler);
private:
ClientFactory();
Aws::SDKOptions aws_options;
std::atomic<bool> s3_requests_logging_enabled;
};
}
}

src/IO/S3/Credentials.cpp (new file, 522 lines)

@@ -0,0 +1,522 @@
#include <IO/S3/Credentials.h>
#if USE_AWS_S3
# include <aws/core/Version.h>
# include <aws/core/platform/OSVersionInfo.h>
# include <aws/core/auth/STSCredentialsProvider.h>
# include <aws/core/platform/Environment.h>
# include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h>
# include <aws/core/utils/json/JsonSerializer.h>
# include <aws/core/utils/UUID.h>
# include <aws/core/http/HttpClientFactory.h>
# include <Common/logger_useful.h>
# include <IO/S3/PocoHTTPClient.h>
# include <IO/S3/PocoHTTPClientFactory.h>
# include <IO/S3/Client.h>
# include <fstream>
namespace DB::S3
{
AWSEC2MetadataClient::AWSEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration, const char * endpoint_)
: Aws::Internal::AWSHttpResourceClient(client_configuration)
, endpoint(endpoint_)
, logger(&Poco::Logger::get("AWSEC2InstanceProfileConfigLoader"))
{
}
Aws::String AWSEC2MetadataClient::GetResource(const char * resource_path) const
{
return GetResource(endpoint.c_str(), resource_path, nullptr/*authToken*/);
}
Aws::String AWSEC2MetadataClient::getDefaultCredentials() const
{
String credentials_string;
{
std::lock_guard locker(token_mutex);
LOG_TRACE(logger, "Getting default credentials for ec2 instance from {}", endpoint);
auto result = GetResourceWithAWSWebServiceResult(endpoint.c_str(), EC2_SECURITY_CREDENTIALS_RESOURCE, nullptr);
credentials_string = result.GetPayload();
if (result.GetResponseCode() == Aws::Http::HttpResponseCode::UNAUTHORIZED)
{
return {};
}
}
String trimmed_credentials_string = Aws::Utils::StringUtils::Trim(credentials_string.c_str());
if (trimmed_credentials_string.empty())
return {};
std::vector<String> security_credentials = Aws::Utils::StringUtils::Split(trimmed_credentials_string, '\n');
LOG_DEBUG(logger, "Calling EC2MetadataService resource, {} returned credential string {}.",
EC2_SECURITY_CREDENTIALS_RESOURCE, trimmed_credentials_string);
if (security_credentials.empty())
{
LOG_WARNING(logger, "Initial call to EC2MetadataService to get credentials failed.");
return {};
}
Aws::StringStream ss;
ss << EC2_SECURITY_CREDENTIALS_RESOURCE << "/" << security_credentials[0];
LOG_DEBUG(logger, "Calling EC2MetadataService resource {}.", ss.str());
return GetResource(ss.str().c_str());
}
Aws::String AWSEC2MetadataClient::awsComputeUserAgentString()
{
Aws::StringStream ss;
ss << "aws-sdk-cpp/" << Aws::Version::GetVersionString() << " " << Aws::OSVersionInfo::ComputeOSVersionString()
<< " " << Aws::Version::GetCompilerVersionString();
return ss.str();
}
Aws::String AWSEC2MetadataClient::getDefaultCredentialsSecurely() const
{
String user_agent_string = awsComputeUserAgentString();
String new_token;
{
std::lock_guard locker(token_mutex);
Aws::StringStream ss;
ss << endpoint << EC2_IMDS_TOKEN_RESOURCE;
std::shared_ptr<Aws::Http::HttpRequest> token_request(Aws::Http::CreateHttpRequest(ss.str(), Aws::Http::HttpMethod::HTTP_PUT,
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod));
token_request->SetHeaderValue(EC2_IMDS_TOKEN_TTL_HEADER, EC2_IMDS_TOKEN_TTL_DEFAULT_VALUE);
token_request->SetUserAgent(user_agent_string);
LOG_TRACE(logger, "Calling EC2MetadataService to get token.");
auto result = GetResourceWithAWSWebServiceResult(token_request);
const String & token_string = result.GetPayload();
new_token = Aws::Utils::StringUtils::Trim(token_string.c_str());
if (result.GetResponseCode() == Aws::Http::HttpResponseCode::BAD_REQUEST)
{
return {};
}
else if (result.GetResponseCode() != Aws::Http::HttpResponseCode::OK || new_token.empty())
{
LOG_TRACE(logger, "Calling EC2MetadataService to get token failed, falling back to less secure way.");
return getDefaultCredentials();
}
token = new_token;
}
String url = endpoint + EC2_SECURITY_CREDENTIALS_RESOURCE;
std::shared_ptr<Aws::Http::HttpRequest> profile_request(Aws::Http::CreateHttpRequest(url,
Aws::Http::HttpMethod::HTTP_GET,
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod));
profile_request->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, new_token);
profile_request->SetUserAgent(user_agent_string);
String profile_string = GetResourceWithAWSWebServiceResult(profile_request).GetPayload();
String trimmed_profile_string = Aws::Utils::StringUtils::Trim(profile_string.c_str());
std::vector<String> security_credentials = Aws::Utils::StringUtils::Split(trimmed_profile_string, '\n');
LOG_DEBUG(logger, "Calling EC2MetadataService resource, {} with token returned profile string {}.",
EC2_SECURITY_CREDENTIALS_RESOURCE, trimmed_profile_string);
if (security_credentials.empty())
{
LOG_WARNING(logger, "Calling EC2Metadataservice to get profiles failed.");
return {};
}
Aws::StringStream ss;
ss << endpoint << EC2_SECURITY_CREDENTIALS_RESOURCE << "/" << security_credentials[0];
std::shared_ptr<Aws::Http::HttpRequest> credentials_request(Aws::Http::CreateHttpRequest(ss.str(),
Aws::Http::HttpMethod::HTTP_GET,
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod));
credentials_request->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, new_token);
credentials_request->SetUserAgent(user_agent_string);
LOG_DEBUG(logger, "Calling EC2MetadataService resource {} with token.", ss.str());
return GetResourceWithAWSWebServiceResult(credentials_request).GetPayload();
}
Aws::String AWSEC2MetadataClient::getCurrentRegion() const
{
return Aws::Region::AWS_GLOBAL;
}
std::shared_ptr<AWSEC2MetadataClient> InitEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration)
{
Aws::String ec2_metadata_service_endpoint = Aws::Environment::GetEnv("AWS_EC2_METADATA_SERVICE_ENDPOINT");
auto * logger = &Poco::Logger::get("AWSEC2InstanceProfileConfigLoader");
if (ec2_metadata_service_endpoint.empty())
{
Aws::String ec2_metadata_service_endpoint_mode = Aws::Environment::GetEnv("AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE");
if (ec2_metadata_service_endpoint_mode.length() == 0)
{
ec2_metadata_service_endpoint = "http://169.254.169.254"; //default to IPv4 default endpoint
}
else
{
if (ec2_metadata_service_endpoint_mode.length() == 4)
{
if (Aws::Utils::StringUtils::CaselessCompare(ec2_metadata_service_endpoint_mode.c_str(), "ipv4"))
{
ec2_metadata_service_endpoint = "http://169.254.169.254"; //default to IPv4 default endpoint
}
else if (Aws::Utils::StringUtils::CaselessCompare(ec2_metadata_service_endpoint_mode.c_str(), "ipv6"))
{
ec2_metadata_service_endpoint = "http://[fd00:ec2::254]";
}
else
{
LOG_ERROR(logger, "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE can only be set to ipv4 or ipv6, received: {}", ec2_metadata_service_endpoint_mode);
}
}
else
{
LOG_ERROR(logger, "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE can only be set to ipv4 or ipv6, received: {}", ec2_metadata_service_endpoint_mode);
}
}
}
LOG_INFO(logger, "Using IMDS endpoint: {}", ec2_metadata_service_endpoint);
return std::make_shared<AWSEC2MetadataClient>(client_configuration, ec2_metadata_service_endpoint.c_str());
}
AWSEC2InstanceProfileConfigLoader::AWSEC2InstanceProfileConfigLoader(const std::shared_ptr<AWSEC2MetadataClient> & client_, bool use_secure_pull_)
: client(client_)
, use_secure_pull(use_secure_pull_)
, logger(&Poco::Logger::get("AWSEC2InstanceProfileConfigLoader"))
{
}
bool AWSEC2InstanceProfileConfigLoader::LoadInternal()
{
auto credentials_str = use_secure_pull ? client->getDefaultCredentialsSecurely() : client->getDefaultCredentials();
/// See EC2InstanceProfileConfigLoader.
if (credentials_str.empty())
return false;
Aws::Utils::Json::JsonValue credentials_doc(credentials_str);
if (!credentials_doc.WasParseSuccessful())
{
LOG_ERROR(logger, "Failed to parse output from EC2MetadataService.");
return false;
}
String access_key, secret_key, token;
auto credentials_view = credentials_doc.View();
access_key = credentials_view.GetString("AccessKeyId");
LOG_TRACE(logger, "Successfully pulled credentials from EC2MetadataService with access key.");
secret_key = credentials_view.GetString("SecretAccessKey");
token = credentials_view.GetString("Token");
auto region = client->getCurrentRegion();
Aws::Config::Profile profile;
profile.SetCredentials(Aws::Auth::AWSCredentials(access_key, secret_key, token));
profile.SetRegion(region);
profile.SetName(Aws::Config::INSTANCE_PROFILE_KEY);
m_profiles[Aws::Config::INSTANCE_PROFILE_KEY] = profile;
return true;
}
AWSInstanceProfileCredentialsProvider::AWSInstanceProfileCredentialsProvider(const std::shared_ptr<AWSEC2InstanceProfileConfigLoader> & config_loader)
: ec2_metadata_config_loader(config_loader)
, load_frequency_ms(Aws::Auth::REFRESH_THRESHOLD)
, logger(&Poco::Logger::get("AWSInstanceProfileCredentialsProvider"))
{
LOG_INFO(logger, "Creating Instance with injected EC2MetadataClient and refresh rate.");
}
Aws::Auth::AWSCredentials AWSInstanceProfileCredentialsProvider::GetAWSCredentials()
{
refreshIfExpired();
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
auto profile_it = ec2_metadata_config_loader->GetProfiles().find(Aws::Config::INSTANCE_PROFILE_KEY);
if (profile_it != ec2_metadata_config_loader->GetProfiles().end())
{
return profile_it->second.GetCredentials();
}
return Aws::Auth::AWSCredentials();
}
void AWSInstanceProfileCredentialsProvider::Reload()
{
LOG_INFO(logger, "Credentials have expired attempting to repull from EC2 Metadata Service.");
ec2_metadata_config_loader->Load();
AWSCredentialsProvider::Reload();
}
void AWSInstanceProfileCredentialsProvider::refreshIfExpired()
{
LOG_DEBUG(logger, "Checking if latest credential pull has expired.");
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
if (!IsTimeToRefresh(load_frequency_ms))
{
return;
}
guard.UpgradeToWriterLock();
if (!IsTimeToRefresh(load_frequency_ms)) // double-checked lock to avoid refreshing twice
{
return;
}
Reload();
}
AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider::AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider(DB::S3::PocoHTTPClientConfiguration & aws_client_configuration)
: logger(&Poco::Logger::get("AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider"))
{
// check environment variables
String tmp_region = Aws::Environment::GetEnv("AWS_DEFAULT_REGION");
role_arn = Aws::Environment::GetEnv("AWS_ROLE_ARN");
token_file = Aws::Environment::GetEnv("AWS_WEB_IDENTITY_TOKEN_FILE");
session_name = Aws::Environment::GetEnv("AWS_ROLE_SESSION_NAME");
// check profile_config if either m_roleArn or m_tokenFile is not loaded from environment variable
// region source is not enforced, but we need it to construct sts endpoint, if we can't find from environment, we should check if it's set in config file.
if (role_arn.empty() || token_file.empty() || tmp_region.empty())
{
auto profile = Aws::Config::GetCachedConfigProfile(Aws::Auth::GetConfigProfileName());
if (tmp_region.empty())
{
tmp_region = profile.GetRegion();
}
// If either of these two were not found from environment, use whatever found for all three in config file
if (role_arn.empty() || token_file.empty())
{
role_arn = profile.GetRoleArn();
token_file = profile.GetValue("web_identity_token_file");
session_name = profile.GetValue("role_session_name");
}
}
if (token_file.empty())
{
LOG_WARNING(logger, "Token file must be specified to use STS AssumeRole web identity creds provider.");
return; // No need to do further constructing
}
else
{
LOG_DEBUG(logger, "Resolved token_file from profile_config or environment variable to be {}", token_file);
}
if (role_arn.empty())
{
LOG_WARNING(logger, "RoleArn must be specified to use STS AssumeRole web identity creds provider.");
return; // No need to do further constructing
}
else
{
LOG_DEBUG(logger, "Resolved role_arn from profile_config or environment variable to be {}", role_arn);
}
if (tmp_region.empty())
{
tmp_region = Aws::Region::US_EAST_1;
}
else
{
LOG_DEBUG(logger, "Resolved region from profile_config or environment variable to be {}", tmp_region);
}
if (session_name.empty())
{
session_name = Aws::Utils::UUID::RandomUUID();
}
else
{
LOG_DEBUG(logger, "Resolved session_name from profile_config or environment variable to be {}", session_name);
}
aws_client_configuration.scheme = Aws::Http::Scheme::HTTPS;
aws_client_configuration.region = tmp_region;
std::vector<String> retryable_errors;
retryable_errors.push_back("IDPCommunicationError");
retryable_errors.push_back("InvalidIdentityToken");
aws_client_configuration.retryStrategy = std::make_shared<Aws::Client::SpecifiedRetryableErrorsRetryStrategy>(
retryable_errors, /* maxRetries = */3);
client = std::make_unique<Aws::Internal::STSCredentialsClient>(aws_client_configuration);
initialized = true;
LOG_INFO(logger, "Creating STS AssumeRole with web identity creds provider.");
}
Aws::Auth::AWSCredentials AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider::GetAWSCredentials()
{
// A valid client means required information like role arn and token file were constructed correctly.
// We can use this provider to load creds, otherwise, we can just return empty creds.
if (!initialized)
{
return Aws::Auth::AWSCredentials();
}
refreshIfExpired();
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
return credentials;
}
void AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider::Reload()
{
LOG_INFO(logger, "Credentials have expired, attempting to renew from STS.");
std::ifstream token_stream(token_file.data());
if (token_stream)
{
String token_string((std::istreambuf_iterator<char>(token_stream)), std::istreambuf_iterator<char>());
token = token_string;
}
else
{
LOG_INFO(logger, "Can't open token file: {}", token_file);
return;
}
Aws::Internal::STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request{session_name, role_arn, token};
auto result = client->GetAssumeRoleWithWebIdentityCredentials(request);
LOG_TRACE(logger, "Successfully retrieved credentials.");
credentials = result.creds;
}
void AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider::refreshIfExpired()
{
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
if (!credentials.IsExpiredOrEmpty())
{
return;
}
guard.UpgradeToWriterLock();
if (!credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice
{
return;
}
Reload();
}
S3CredentialsProviderChain::S3CredentialsProviderChain(
const DB::S3::PocoHTTPClientConfiguration & configuration,
const Aws::Auth::AWSCredentials & credentials,
bool use_environment_credentials,
bool use_insecure_imds_request)
{
auto * logger = &Poco::Logger::get("S3CredentialsProviderChain");
/// add explicit credentials to the front of the chain
/// because it's manually defined by the user
if (!credentials.IsEmpty())
{
AddProvider(std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(credentials));
return;
}
if (use_environment_credentials)
{
static const char AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI[] = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI";
static const char AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI[] = "AWS_CONTAINER_CREDENTIALS_FULL_URI";
static const char AWS_ECS_CONTAINER_AUTHORIZATION_TOKEN[] = "AWS_CONTAINER_AUTHORIZATION_TOKEN";
static const char AWS_EC2_METADATA_DISABLED[] = "AWS_EC2_METADATA_DISABLED";
/// The only difference from DefaultAWSCredentialsProviderChain::DefaultAWSCredentialsProviderChain()
/// is that this chain uses custom ClientConfiguration. Also we removed process provider because it's useless in our case.
///
/// AWS API tries credentials providers one by one. Some of the providers (like ProfileConfigFileAWSCredentialsProvider) can be
/// quite verbose even if nobody configured them. So we use our provider first, and only after it the default providers.
{
DB::S3::PocoHTTPClientConfiguration aws_client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration(
configuration.region,
configuration.remote_host_filter,
configuration.s3_max_redirects,
configuration.enable_s3_requests_logging,
configuration.for_disk_s3,
configuration.get_request_throttler,
configuration.put_request_throttler);
AddProvider(std::make_shared<AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider>(aws_client_configuration));
}
AddProvider(std::make_shared<Aws::Auth::EnvironmentAWSCredentialsProvider>());
/// ECS TaskRole Credentials only available when ENVIRONMENT VARIABLE is set.
const auto relative_uri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI);
LOG_DEBUG(logger, "The environment variable value {} is {}", AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI,
relative_uri);
const auto absolute_uri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI);
LOG_DEBUG(logger, "The environment variable value {} is {}", AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI,
absolute_uri);
const auto ec2_metadata_disabled = Aws::Environment::GetEnv(AWS_EC2_METADATA_DISABLED);
LOG_DEBUG(logger, "The environment variable value {} is {}", AWS_EC2_METADATA_DISABLED,
ec2_metadata_disabled);
if (!relative_uri.empty())
{
AddProvider(std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>(relative_uri.c_str()));
LOG_INFO(logger, "Added ECS metadata service credentials provider with relative path: [{}] to the provider chain.",
relative_uri);
}
else if (!absolute_uri.empty())
{
const auto token = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_AUTHORIZATION_TOKEN);
AddProvider(std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>(absolute_uri.c_str(), token.c_str()));
/// DO NOT log the value of the authorization token for security purposes.
LOG_INFO(logger, "Added ECS credentials provider with URI: [{}] to the provider chain with a{} authorization token.",
absolute_uri, token.empty() ? "n empty" : " non-empty");
}
else if (Aws::Utils::StringUtils::ToLower(ec2_metadata_disabled.c_str()) != "true")
{
DB::S3::PocoHTTPClientConfiguration aws_client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration(
configuration.region,
configuration.remote_host_filter,
configuration.s3_max_redirects,
configuration.enable_s3_requests_logging,
configuration.for_disk_s3,
configuration.get_request_throttler,
configuration.put_request_throttler);
/// See MakeDefaultHttpResourceClientConfiguration().
/// This is part of EC2 metadata client, but unfortunately it can't be accessed from outside
/// of contrib/aws/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp
aws_client_configuration.maxConnections = 2;
aws_client_configuration.scheme = Aws::Http::Scheme::HTTP;
/// Explicitly set the proxy settings to empty/zero to avoid relying on defaults that could potentially change
/// in the future.
aws_client_configuration.proxyHost = "";
aws_client_configuration.proxyUserName = "";
aws_client_configuration.proxyPassword = "";
aws_client_configuration.proxyPort = 0;
/// EC2MetadataService throttles by delaying the response, so the service client should set a large read timeout.
/// EC2MetadataService delay is on the order of seconds, so it only makes sense to retry after a couple of seconds.
aws_client_configuration.connectTimeoutMs = 1000;
aws_client_configuration.requestTimeoutMs = 1000;
aws_client_configuration.retryStrategy = std::make_shared<Aws::Client::DefaultRetryStrategy>(1, 1000);
auto ec2_metadata_client = InitEC2MetadataClient(aws_client_configuration);
auto config_loader = std::make_shared<AWSEC2InstanceProfileConfigLoader>(ec2_metadata_client, !use_insecure_imds_request);
AddProvider(std::make_shared<AWSInstanceProfileCredentialsProvider>(config_loader));
LOG_INFO(logger, "Added EC2 metadata service credentials provider to the provider chain.");
}
}
/// Quite verbose provider (complains if the credentials file doesn't exist), so it's the last one
/// in the chain.
AddProvider(std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>());
}
}
#endif

src/IO/S3/Credentials.h (new file, 127 lines)

@@ -0,0 +1,127 @@
#pragma once
#include "config.h"
#if USE_AWS_S3
# include <aws/core/client/ClientConfiguration.h>
# include <aws/core/internal/AWSHttpResourceClient.h>
# include <aws/core/config/AWSProfileConfigLoader.h>
# include <aws/core/auth/AWSCredentialsProvider.h>
# include <aws/core/auth/AWSCredentialsProviderChain.h>
# include <Common/logger_useful.h>
# include <IO/S3/PocoHTTPClient.h>
namespace DB::S3
{
class AWSEC2MetadataClient : public Aws::Internal::AWSHttpResourceClient
{
static constexpr char EC2_SECURITY_CREDENTIALS_RESOURCE[] = "/latest/meta-data/iam/security-credentials";
static constexpr char EC2_IMDS_TOKEN_RESOURCE[] = "/latest/api/token";
static constexpr char EC2_IMDS_TOKEN_HEADER[] = "x-aws-ec2-metadata-token";
static constexpr char EC2_IMDS_TOKEN_TTL_DEFAULT_VALUE[] = "21600";
static constexpr char EC2_IMDS_TOKEN_TTL_HEADER[] = "x-aws-ec2-metadata-token-ttl-seconds";
public:
/// See EC2MetadataClient.
explicit AWSEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration, const char * endpoint_);
AWSEC2MetadataClient& operator =(const AWSEC2MetadataClient & rhs) = delete;
AWSEC2MetadataClient(const AWSEC2MetadataClient & rhs) = delete;
AWSEC2MetadataClient& operator =(const AWSEC2MetadataClient && rhs) = delete;
AWSEC2MetadataClient(const AWSEC2MetadataClient && rhs) = delete;
~AWSEC2MetadataClient() override = default;
using Aws::Internal::AWSHttpResourceClient::GetResource;
virtual Aws::String GetResource(const char * resource_path) const;
virtual Aws::String getDefaultCredentials() const;
static Aws::String awsComputeUserAgentString();
virtual Aws::String getDefaultCredentialsSecurely() const;
virtual Aws::String getCurrentRegion() const;
private:
const Aws::String endpoint;
mutable std::recursive_mutex token_mutex;
mutable Aws::String token;
Poco::Logger * logger;
};
std::shared_ptr<AWSEC2MetadataClient> InitEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration);
class AWSEC2InstanceProfileConfigLoader : public Aws::Config::AWSProfileConfigLoader
{
public:
explicit AWSEC2InstanceProfileConfigLoader(const std::shared_ptr<AWSEC2MetadataClient> & client_, bool use_secure_pull_);
~AWSEC2InstanceProfileConfigLoader() override = default;
protected:
bool LoadInternal() override;
private:
std::shared_ptr<AWSEC2MetadataClient> client;
bool use_secure_pull;
Poco::Logger * logger;
};
class AWSInstanceProfileCredentialsProvider : public Aws::Auth::AWSCredentialsProvider
{
public:
/// See InstanceProfileCredentialsProvider.
explicit AWSInstanceProfileCredentialsProvider(const std::shared_ptr<AWSEC2InstanceProfileConfigLoader> & config_loader);
Aws::Auth::AWSCredentials GetAWSCredentials() override;
protected:
void Reload() override;
private:
void refreshIfExpired();
std::shared_ptr<AWSEC2InstanceProfileConfigLoader> ec2_metadata_config_loader;
Int64 load_frequency_ms;
Poco::Logger * logger;
};
class AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider : public Aws::Auth::AWSCredentialsProvider
{
/// See STSAssumeRoleWebIdentityCredentialsProvider.
public:
explicit AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider(DB::S3::PocoHTTPClientConfiguration & aws_client_configuration);
Aws::Auth::AWSCredentials GetAWSCredentials() override;
protected:
void Reload() override;
private:
void refreshIfExpired();
std::unique_ptr<Aws::Internal::STSCredentialsClient> client;
Aws::Auth::AWSCredentials credentials;
Aws::String role_arn;
Aws::String token_file;
Aws::String session_name;
Aws::String token;
bool initialized = false;
Poco::Logger * logger;
};
class S3CredentialsProviderChain : public Aws::Auth::AWSCredentialsProviderChain
{
public:
S3CredentialsProviderChain(const DB::S3::PocoHTTPClientConfiguration & configuration, const Aws::Auth::AWSCredentials & credentials, bool use_environment_credentials, bool use_insecure_imds_request);
};
}
#endif


@@ -12,25 +12,12 @@
# include <IO/HTTPHeaderEntries.h>
# include <Storages/StorageS3Settings.h>
# include <aws/core/Version.h>
# include <aws/core/auth/AWSCredentialsProvider.h>
# include <aws/core/auth/AWSCredentialsProviderChain.h>
# include <aws/core/auth/STSCredentialsProvider.h>
# include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h>
# include <aws/core/platform/Environment.h>
# include <aws/core/platform/OSVersionInfo.h>
# include <aws/core/utils/json/JsonSerializer.h>
# include <aws/core/utils/logging/LogMacros.h>
# include <aws/core/utils/logging/LogSystemInterface.h>
# include <aws/core/utils/HashingUtils.h>
# include <aws/core/utils/UUID.h>
# include <aws/core/http/HttpClientFactory.h>
# include <IO/S3/PocoHTTPClientFactory.h>
# include <IO/S3/PocoHTTPClient.h>
# include <IO/S3/Client.h>
# include <IO/S3/URI.h>
# include <IO/S3/Requests.h>
# include <IO/S3/Credentials.h>
# include <Common/logger_useful.h>
# include <fstream>
@@ -65,750 +52,11 @@ bool S3Exception::isRetryableError() const
}
namespace
{
const char * S3_LOGGER_TAG_NAMES[][2] = {
{"AWSClient", "AWSClient"},
{"AWSAuthV4Signer", "AWSClient (AWSAuthV4Signer)"},
};
const std::pair<DB::LogsLevel, Poco::Message::Priority> & convertLogLevel(Aws::Utils::Logging::LogLevel log_level)
{
/// We map levels to our own logger 1 to 1, except WARN+ levels. In most cases we recover from such errors with retries
/// and don't want to see them as Errors in our logs.
static const std::unordered_map<Aws::Utils::Logging::LogLevel, std::pair<DB::LogsLevel, Poco::Message::Priority>> mapping =
{
{Aws::Utils::Logging::LogLevel::Off, {DB::LogsLevel::none, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Fatal, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Error, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Warn, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Info, {DB::LogsLevel::information, Poco::Message::PRIO_INFORMATION}},
{Aws::Utils::Logging::LogLevel::Debug, {DB::LogsLevel::debug, Poco::Message::PRIO_TEST}},
{Aws::Utils::Logging::LogLevel::Trace, {DB::LogsLevel::trace, Poco::Message::PRIO_TEST}},
};
return mapping.at(log_level);
}
class AWSLogger final : public Aws::Utils::Logging::LogSystemInterface
{
public:
explicit AWSLogger(bool enable_s3_requests_logging_)
:enable_s3_requests_logging(enable_s3_requests_logging_)
{
for (auto [tag, name] : S3_LOGGER_TAG_NAMES)
tag_loggers[tag] = &Poco::Logger::get(name);
default_logger = tag_loggers[S3_LOGGER_TAG_NAMES[0][0]];
}
~AWSLogger() final = default;
Aws::Utils::Logging::LogLevel GetLogLevel() const final
{
if (enable_s3_requests_logging)
return Aws::Utils::Logging::LogLevel::Trace;
else
return Aws::Utils::Logging::LogLevel::Info;
}
void Log(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * format_str, ...) final // NOLINT
{
callLogImpl(log_level, tag, format_str); /// FIXME. Variadic arguments?
}
void LogStream(Aws::Utils::Logging::LogLevel log_level, const char * tag, const Aws::OStringStream & message_stream) final
{
callLogImpl(log_level, tag, message_stream.str().c_str());
}
void callLogImpl(Aws::Utils::Logging::LogLevel log_level, const char * tag, const char * message)
{
const auto & [level, prio] = convertLogLevel(log_level);
if (tag_loggers.contains(tag))
{
LOG_IMPL(tag_loggers[tag], level, prio, fmt::runtime(message));
}
else
{
LOG_IMPL(default_logger, level, prio, "{}: {}", tag, message);
}
}
void Flush() final {}
private:
Poco::Logger * default_logger;
bool enable_s3_requests_logging;
std::unordered_map<String, Poco::Logger *> tag_loggers;
};
class AWSEC2MetadataClient : public Aws::Internal::AWSHttpResourceClient
{
static constexpr char EC2_SECURITY_CREDENTIALS_RESOURCE[] = "/latest/meta-data/iam/security-credentials";
static constexpr char EC2_IMDS_TOKEN_RESOURCE[] = "/latest/api/token";
static constexpr char EC2_IMDS_TOKEN_HEADER[] = "x-aws-ec2-metadata-token";
static constexpr char EC2_IMDS_TOKEN_TTL_DEFAULT_VALUE[] = "21600";
static constexpr char EC2_IMDS_TOKEN_TTL_HEADER[] = "x-aws-ec2-metadata-token-ttl-seconds";
public:
/// See EC2MetadataClient.
explicit AWSEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration, const char * endpoint_)
: Aws::Internal::AWSHttpResourceClient(client_configuration)
, endpoint(endpoint_)
, logger(&Poco::Logger::get("AWSEC2InstanceProfileConfigLoader"))
{
}
AWSEC2MetadataClient& operator =(const AWSEC2MetadataClient & rhs) = delete;
AWSEC2MetadataClient(const AWSEC2MetadataClient & rhs) = delete;
AWSEC2MetadataClient& operator =(const AWSEC2MetadataClient && rhs) = delete;
AWSEC2MetadataClient(const AWSEC2MetadataClient && rhs) = delete;
~AWSEC2MetadataClient() override = default;
using Aws::Internal::AWSHttpResourceClient::GetResource;
virtual Aws::String GetResource(const char * resource_path) const
{
return GetResource(endpoint.c_str(), resource_path, nullptr/*authToken*/);
}
virtual Aws::String getDefaultCredentials() const
{
String credentials_string;
{
std::lock_guard locker(token_mutex);
LOG_TRACE(logger, "Getting default credentials for ec2 instance from {}", endpoint);
auto result = GetResourceWithAWSWebServiceResult(endpoint.c_str(), EC2_SECURITY_CREDENTIALS_RESOURCE, nullptr);
credentials_string = result.GetPayload();
if (result.GetResponseCode() == Aws::Http::HttpResponseCode::UNAUTHORIZED)
{
return {};
}
}
String trimmed_credentials_string = Aws::Utils::StringUtils::Trim(credentials_string.c_str());
if (trimmed_credentials_string.empty())
return {};
std::vector<String> security_credentials = Aws::Utils::StringUtils::Split(trimmed_credentials_string, '\n');
LOG_DEBUG(logger, "Calling EC2MetadataService resource, {} returned credential string {}.",
EC2_SECURITY_CREDENTIALS_RESOURCE, trimmed_credentials_string);
if (security_credentials.empty())
{
LOG_WARNING(logger, "Initial call to EC2MetadataService to get credentials failed.");
return {};
}
Aws::StringStream ss;
ss << EC2_SECURITY_CREDENTIALS_RESOURCE << "/" << security_credentials[0];
LOG_DEBUG(logger, "Calling EC2MetadataService resource {}.", ss.str());
return GetResource(ss.str().c_str());
}
static Aws::String awsComputeUserAgentString()
{
Aws::StringStream ss;
ss << "aws-sdk-cpp/" << Aws::Version::GetVersionString() << " " << Aws::OSVersionInfo::ComputeOSVersionString()
<< " " << Aws::Version::GetCompilerVersionString();
return ss.str();
}
virtual Aws::String getDefaultCredentialsSecurely() const
{
String user_agent_string = awsComputeUserAgentString();
String new_token;
{
std::lock_guard locker(token_mutex);
Aws::StringStream ss;
ss << endpoint << EC2_IMDS_TOKEN_RESOURCE;
std::shared_ptr<Aws::Http::HttpRequest> token_request(Aws::Http::CreateHttpRequest(ss.str(), Aws::Http::HttpMethod::HTTP_PUT,
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod));
token_request->SetHeaderValue(EC2_IMDS_TOKEN_TTL_HEADER, EC2_IMDS_TOKEN_TTL_DEFAULT_VALUE);
token_request->SetUserAgent(user_agent_string);
LOG_TRACE(logger, "Calling EC2MetadataService to get token.");
auto result = GetResourceWithAWSWebServiceResult(token_request);
const String & token_string = result.GetPayload();
new_token = Aws::Utils::StringUtils::Trim(token_string.c_str());
if (result.GetResponseCode() == Aws::Http::HttpResponseCode::BAD_REQUEST)
{
return {};
}
else if (result.GetResponseCode() != Aws::Http::HttpResponseCode::OK || new_token.empty())
{
LOG_TRACE(logger, "Calling EC2MetadataService to get token failed, falling back to less secure way.");
return getDefaultCredentials();
}
token = new_token;
}
String url = endpoint + EC2_SECURITY_CREDENTIALS_RESOURCE;
std::shared_ptr<Aws::Http::HttpRequest> profile_request(Aws::Http::CreateHttpRequest(url,
Aws::Http::HttpMethod::HTTP_GET,
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod));
profile_request->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, new_token);
profile_request->SetUserAgent(user_agent_string);
String profile_string = GetResourceWithAWSWebServiceResult(profile_request).GetPayload();
String trimmed_profile_string = Aws::Utils::StringUtils::Trim(profile_string.c_str());
std::vector<String> security_credentials = Aws::Utils::StringUtils::Split(trimmed_profile_string, '\n');
LOG_DEBUG(logger, "Calling EC2MetadataService resource, {} with token returned profile string {}.",
EC2_SECURITY_CREDENTIALS_RESOURCE, trimmed_profile_string);
if (security_credentials.empty())
{
LOG_WARNING(logger, "Calling EC2Metadataservice to get profiles failed.");
return {};
}
Aws::StringStream ss;
ss << endpoint << EC2_SECURITY_CREDENTIALS_RESOURCE << "/" << security_credentials[0];
std::shared_ptr<Aws::Http::HttpRequest> credentials_request(Aws::Http::CreateHttpRequest(ss.str(),
Aws::Http::HttpMethod::HTTP_GET,
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod));
credentials_request->SetHeaderValue(EC2_IMDS_TOKEN_HEADER, new_token);
credentials_request->SetUserAgent(user_agent_string);
LOG_DEBUG(logger, "Calling EC2MetadataService resource {} with token.", ss.str());
return GetResourceWithAWSWebServiceResult(credentials_request).GetPayload();
}
virtual Aws::String getCurrentRegion() const
{
return Aws::Region::AWS_GLOBAL;
}
private:
const Aws::String endpoint;
mutable std::recursive_mutex token_mutex;
mutable Aws::String token;
Poco::Logger * logger;
};
std::shared_ptr<AWSEC2MetadataClient> InitEC2MetadataClient(const Aws::Client::ClientConfiguration & client_configuration)
{
Aws::String ec2_metadata_service_endpoint = Aws::Environment::GetEnv("AWS_EC2_METADATA_SERVICE_ENDPOINT");
auto * logger = &Poco::Logger::get("AWSEC2InstanceProfileConfigLoader");
if (ec2_metadata_service_endpoint.empty())
{
Aws::String ec2_metadata_service_endpoint_mode = Aws::Environment::GetEnv("AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE");
if (ec2_metadata_service_endpoint_mode.length() == 0)
{
ec2_metadata_service_endpoint = "http://169.254.169.254"; //default to IPv4 default endpoint
}
else
{
if (ec2_metadata_service_endpoint_mode.length() == 4)
{
if (Aws::Utils::StringUtils::CaselessCompare(ec2_metadata_service_endpoint_mode.c_str(), "ipv4"))
{
ec2_metadata_service_endpoint = "http://169.254.169.254"; //default to IPv4 default endpoint
}
else if (Aws::Utils::StringUtils::CaselessCompare(ec2_metadata_service_endpoint_mode.c_str(), "ipv6"))
{
ec2_metadata_service_endpoint = "http://[fd00:ec2::254]";
}
else
{
LOG_ERROR(logger, "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE can only be set to ipv4 or ipv6, received: {}", ec2_metadata_service_endpoint_mode);
}
}
else
{
LOG_ERROR(logger, "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE can only be set to ipv4 or ipv6, received: {}", ec2_metadata_service_endpoint_mode);
}
}
}
LOG_INFO(logger, "Using IMDS endpoint: {}", ec2_metadata_service_endpoint);
return std::make_shared<AWSEC2MetadataClient>(client_configuration, ec2_metadata_service_endpoint.c_str());
}
class AWSEC2InstanceProfileConfigLoader : public Aws::Config::AWSProfileConfigLoader
{
public:
explicit AWSEC2InstanceProfileConfigLoader(const std::shared_ptr<AWSEC2MetadataClient> & client_, bool use_secure_pull_)
: client(client_)
, use_secure_pull(use_secure_pull_)
, logger(&Poco::Logger::get("AWSEC2InstanceProfileConfigLoader"))
{
}
~AWSEC2InstanceProfileConfigLoader() override = default;
protected:
bool LoadInternal() override
{
auto credentials_str = use_secure_pull ? client->getDefaultCredentialsSecurely() : client->getDefaultCredentials();
/// See EC2InstanceProfileConfigLoader.
if (credentials_str.empty())
return false;
Aws::Utils::Json::JsonValue credentials_doc(credentials_str);
if (!credentials_doc.WasParseSuccessful())
{
LOG_ERROR(logger, "Failed to parse output from EC2MetadataService.");
return false;
}
String access_key, secret_key, token;
auto credentials_view = credentials_doc.View();
access_key = credentials_view.GetString("AccessKeyId");
LOG_TRACE(logger, "Successfully pulled credentials from EC2MetadataService with access key.");
secret_key = credentials_view.GetString("SecretAccessKey");
token = credentials_view.GetString("Token");
auto region = client->getCurrentRegion();
Aws::Config::Profile profile;
profile.SetCredentials(Aws::Auth::AWSCredentials(access_key, secret_key, token));
profile.SetRegion(region);
profile.SetName(Aws::Config::INSTANCE_PROFILE_KEY);
m_profiles[Aws::Config::INSTANCE_PROFILE_KEY] = profile;
return true;
}
private:
std::shared_ptr<AWSEC2MetadataClient> client;
bool use_secure_pull;
Poco::Logger * logger;
};
class AWSInstanceProfileCredentialsProvider : public Aws::Auth::AWSCredentialsProvider
{
public:
/// See InstanceProfileCredentialsProvider.
explicit AWSInstanceProfileCredentialsProvider(const std::shared_ptr<AWSEC2InstanceProfileConfigLoader> & config_loader)
: ec2_metadata_config_loader(config_loader)
, load_frequency_ms(Aws::Auth::REFRESH_THRESHOLD)
, logger(&Poco::Logger::get("AWSInstanceProfileCredentialsProvider"))
{
LOG_INFO(logger, "Creating Instance with injected EC2MetadataClient and refresh rate.");
}
Aws::Auth::AWSCredentials GetAWSCredentials() override
{
refreshIfExpired();
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
auto profile_it = ec2_metadata_config_loader->GetProfiles().find(Aws::Config::INSTANCE_PROFILE_KEY);
if (profile_it != ec2_metadata_config_loader->GetProfiles().end())
{
return profile_it->second.GetCredentials();
}
return Aws::Auth::AWSCredentials();
}
protected:
void Reload() override
{
LOG_INFO(logger, "Credentials have expired attempting to repull from EC2 Metadata Service.");
ec2_metadata_config_loader->Load();
AWSCredentialsProvider::Reload();
}
private:
void refreshIfExpired()
{
LOG_DEBUG(logger, "Checking if latest credential pull has expired.");
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
if (!IsTimeToRefresh(load_frequency_ms))
{
return;
}
guard.UpgradeToWriterLock();
if (!IsTimeToRefresh(load_frequency_ms)) // double-checked lock to avoid refreshing twice
{
return;
}
Reload();
}
std::shared_ptr<AWSEC2InstanceProfileConfigLoader> ec2_metadata_config_loader;
Int64 load_frequency_ms;
Poco::Logger * logger;
};
class AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider : public Aws::Auth::AWSCredentialsProvider
{
/// See STSAssumeRoleWebIdentityCredentialsProvider.
public:
explicit AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider(DB::S3::PocoHTTPClientConfiguration & aws_client_configuration)
: logger(&Poco::Logger::get("AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider"))
{
// check environment variables
String tmp_region = Aws::Environment::GetEnv("AWS_DEFAULT_REGION");
role_arn = Aws::Environment::GetEnv("AWS_ROLE_ARN");
token_file = Aws::Environment::GetEnv("AWS_WEB_IDENTITY_TOKEN_FILE");
session_name = Aws::Environment::GetEnv("AWS_ROLE_SESSION_NAME");
// check profile_config if either m_roleArn or m_tokenFile is not loaded from environment variable
// region source is not enforced, but we need it to construct sts endpoint, if we can't find from environment, we should check if it's set in config file.
if (role_arn.empty() || token_file.empty() || tmp_region.empty())
{
auto profile = Aws::Config::GetCachedConfigProfile(Aws::Auth::GetConfigProfileName());
if (tmp_region.empty())
{
tmp_region = profile.GetRegion();
}
// If either of these two were not found from environment, use whatever found for all three in config file
if (role_arn.empty() || token_file.empty())
{
role_arn = profile.GetRoleArn();
token_file = profile.GetValue("web_identity_token_file");
session_name = profile.GetValue("role_session_name");
}
}
if (token_file.empty())
{
LOG_WARNING(logger, "Token file must be specified to use STS AssumeRole web identity creds provider.");
return; // No need to do further constructing
}
else
{
LOG_DEBUG(logger, "Resolved token_file from profile_config or environment variable to be {}", token_file);
}
if (role_arn.empty())
{
LOG_WARNING(logger, "RoleArn must be specified to use STS AssumeRole web identity creds provider.");
return; // No need to do further constructing
}
else
{
LOG_DEBUG(logger, "Resolved role_arn from profile_config or environment variable to be {}", role_arn);
}
if (tmp_region.empty())
{
tmp_region = Aws::Region::US_EAST_1;
}
else
{
LOG_DEBUG(logger, "Resolved region from profile_config or environment variable to be {}", tmp_region);
}
if (session_name.empty())
{
session_name = Aws::Utils::UUID::RandomUUID();
}
else
{
LOG_DEBUG(logger, "Resolved session_name from profile_config or environment variable to be {}", session_name);
}
aws_client_configuration.scheme = Aws::Http::Scheme::HTTPS;
aws_client_configuration.region = tmp_region;
std::vector<String> retryable_errors;
retryable_errors.push_back("IDPCommunicationError");
retryable_errors.push_back("InvalidIdentityToken");
aws_client_configuration.retryStrategy = std::make_shared<Aws::Client::SpecifiedRetryableErrorsRetryStrategy>(
retryable_errors, /* maxRetries = */3);
client = std::make_unique<Aws::Internal::STSCredentialsClient>(aws_client_configuration);
initialized = true;
LOG_INFO(logger, "Creating STS AssumeRole with web identity creds provider.");
}
Aws::Auth::AWSCredentials GetAWSCredentials() override
{
// A valid client means required information like role arn and token file were constructed correctly.
// We can use this provider to load creds, otherwise, we can just return empty creds.
if (!initialized)
{
return Aws::Auth::AWSCredentials();
}
refreshIfExpired();
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
return credentials;
}
protected:
void Reload() override
{
LOG_INFO(logger, "Credentials have expired, attempting to renew from STS.");
std::ifstream token_stream(token_file.data());
if (token_stream)
{
String token_string((std::istreambuf_iterator<char>(token_stream)), std::istreambuf_iterator<char>());
token = token_string;
}
else
{
LOG_INFO(logger, "Can't open token file: {}", token_file);
return;
}
Aws::Internal::STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request{session_name, role_arn, token};
auto result = client->GetAssumeRoleWithWebIdentityCredentials(request);
LOG_TRACE(logger, "Successfully retrieved credentials.");
credentials = result.creds;
}
private:
void refreshIfExpired()
{
Aws::Utils::Threading::ReaderLockGuard guard(m_reloadLock);
if (!credentials.IsExpiredOrEmpty())
{
return;
}
guard.UpgradeToWriterLock();
if (!credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice
{
return;
}
Reload();
}
std::unique_ptr<Aws::Internal::STSCredentialsClient> client;
Aws::Auth::AWSCredentials credentials;
Aws::String role_arn;
Aws::String token_file;
Aws::String session_name;
Aws::String token;
bool initialized = false;
Poco::Logger * logger;
};
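/// A minimal standalone sketch of the double-checked locking used in refreshIfExpired() above,
/// written with standard primitives (illustrative names only; requires <shared_mutex>). Unlike the
/// SDK's ReaderLockGuard, std::shared_mutex cannot upgrade a reader lock in place, so this version
/// releases the shared lock before taking the exclusive one:
class DoubleCheckedRefreshSketch
{
public:
    void refreshIfExpired()
    {
        {
            std::shared_lock read_lock(mutex); /// fast path: many concurrent readers
            if (!expired)
                return;
        }
        std::unique_lock write_lock(mutex); /// slow path: exclusive access
        if (!expired) /// re-check: another thread may have refreshed while we waited
            return;
        reload();
        expired = false;
    }
private:
    void reload() { /* fetch a fresh value */ }
    std::shared_mutex mutex;
    bool expired = true;
};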
class S3CredentialsProviderChain : public Aws::Auth::AWSCredentialsProviderChain
{
public:
S3CredentialsProviderChain(const DB::S3::PocoHTTPClientConfiguration & configuration, const Aws::Auth::AWSCredentials & credentials, bool use_environment_credentials, bool use_insecure_imds_request)
{
auto * logger = &Poco::Logger::get("S3CredentialsProviderChain");
/// add explicit credentials to the front of the chain
/// because they were defined manually by the user
if (!credentials.IsEmpty())
{
AddProvider(std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(credentials));
return;
}
if (use_environment_credentials)
{
static const char AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI[] = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI";
static const char AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI[] = "AWS_CONTAINER_CREDENTIALS_FULL_URI";
static const char AWS_ECS_CONTAINER_AUTHORIZATION_TOKEN[] = "AWS_CONTAINER_AUTHORIZATION_TOKEN";
static const char AWS_EC2_METADATA_DISABLED[] = "AWS_EC2_METADATA_DISABLED";
/// The only difference from DefaultAWSCredentialsProviderChain::DefaultAWSCredentialsProviderChain()
/// is that this chain uses a custom ClientConfiguration. We also removed the process provider because it's useless in our case.
///
/// The AWS API tries credentials providers one by one. Some providers (like ProfileConfigFileAWSCredentialsProvider) can be
/// quite verbose even if nobody configured them, so we use our own providers first and only then the default ones.
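/// The resulting order in this chain (the first provider returning non-empty credentials wins):
///   1. STS AssumeRoleWithWebIdentity (driven by environment variables or the profile config, added just below);
///   2. environment variables (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY);
///   3. ECS task role or EC2 instance metadata (mutually exclusive, see the branches below);
///   4. profile config file (last, because it is the most verbose).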
{
DB::S3::PocoHTTPClientConfiguration aws_client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration(
configuration.region,
configuration.remote_host_filter,
configuration.s3_max_redirects,
configuration.enable_s3_requests_logging,
configuration.for_disk_s3,
configuration.get_request_throttler,
configuration.put_request_throttler);
AddProvider(std::make_shared<AwsAuthSTSAssumeRoleWebIdentityCredentialsProvider>(aws_client_configuration));
}
AddProvider(std::make_shared<Aws::Auth::EnvironmentAWSCredentialsProvider>());
/// ECS TaskRole credentials are only available when the corresponding environment variable is set.
const auto relative_uri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI);
LOG_DEBUG(logger, "The environment variable value {} is {}", AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI,
relative_uri);
const auto absolute_uri = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI);
LOG_DEBUG(logger, "The environment variable value {} is {}", AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI,
absolute_uri);
const auto ec2_metadata_disabled = Aws::Environment::GetEnv(AWS_EC2_METADATA_DISABLED);
LOG_DEBUG(logger, "The environment variable value {} is {}", AWS_EC2_METADATA_DISABLED,
ec2_metadata_disabled);
if (!relative_uri.empty())
{
AddProvider(std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>(relative_uri.c_str()));
LOG_INFO(logger, "Added ECS metadata service credentials provider with relative path: [{}] to the provider chain.",
relative_uri);
}
else if (!absolute_uri.empty())
{
const auto token = Aws::Environment::GetEnv(AWS_ECS_CONTAINER_AUTHORIZATION_TOKEN);
AddProvider(std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>(absolute_uri.c_str(), token.c_str()));
/// DO NOT log the value of the authorization token for security purposes.
LOG_INFO(logger, "Added ECS credentials provider with URI: [{}] to the provider chain with a{} authorization token.",
absolute_uri, token.empty() ? "n empty" : " non-empty");
}
else if (Aws::Utils::StringUtils::ToLower(ec2_metadata_disabled.c_str()) != "true")
{
DB::S3::PocoHTTPClientConfiguration aws_client_configuration = DB::S3::ClientFactory::instance().createClientConfiguration(
configuration.region,
configuration.remote_host_filter,
configuration.s3_max_redirects,
configuration.enable_s3_requests_logging,
configuration.for_disk_s3,
configuration.get_request_throttler,
configuration.put_request_throttler);
/// See MakeDefaultHttpResourceClientConfiguration().
/// This is part of the EC2 metadata client, but unfortunately it can't be accessed from outside
/// of contrib/aws/aws-cpp-sdk-core/source/internal/AWSHttpResourceClient.cpp
aws_client_configuration.maxConnections = 2;
aws_client_configuration.scheme = Aws::Http::Scheme::HTTP;
/// Explicitly set the proxy settings to empty/zero to avoid relying on defaults that could potentially change
/// in the future.
aws_client_configuration.proxyHost = "";
aws_client_configuration.proxyUserName = "";
aws_client_configuration.proxyPassword = "";
aws_client_configuration.proxyPort = 0;
/// EC2MetadataService throttles by delaying the response, so the service client should set a large read timeout.
/// The EC2MetadataService delay is on the order of seconds, so it only makes sense to retry after a couple of seconds.
aws_client_configuration.connectTimeoutMs = 1000;
aws_client_configuration.requestTimeoutMs = 1000;
aws_client_configuration.retryStrategy = std::make_shared<Aws::Client::DefaultRetryStrategy>(1, 1000);
auto ec2_metadata_client = InitEC2MetadataClient(aws_client_configuration);
auto config_loader = std::make_shared<AWSEC2InstanceProfileConfigLoader>(ec2_metadata_client, !use_insecure_imds_request);
AddProvider(std::make_shared<AWSInstanceProfileCredentialsProvider>(config_loader));
LOG_INFO(logger, "Added EC2 metadata service credentials provider to the provider chain.");
}
}
/// Quite a verbose provider (complains if the file with credentials doesn't exist), so it's the last one
/// in the chain.
AddProvider(std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>());
}
};
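/// A hedged sketch of how such a chain resolves credentials; this mirrors the assumed semantics of
/// Aws::Auth::AWSCredentialsProviderChain::GetAWSCredentials(): providers are polled in insertion
/// order and the first non-empty result wins, which is why cheap and quiet providers come first.
Aws::Auth::AWSCredentials resolveFromChainSketch(const std::vector<std::shared_ptr<Aws::Auth::AWSCredentialsProvider>> & providers)
{
    for (const auto & provider : providers)
    {
        auto creds = provider->GetAWSCredentials();
        if (!creds.IsEmpty())
            return creds;
    }
    return {}; /// empty credentials: requests will effectively be anonymous or fail to sign
}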
}
namespace DB
{
namespace ErrorCodes
namespace DB::ErrorCodes
{
extern const int S3_ERROR;
}
namespace S3
{
ClientFactory::ClientFactory()
{
aws_options = Aws::SDKOptions{};
Aws::InitAPI(aws_options);
Aws::Utils::Logging::InitializeAWSLogging(std::make_shared<AWSLogger>(false));
Aws::Http::SetHttpClientFactory(std::make_shared<PocoHTTPClientFactory>());
}
ClientFactory::~ClientFactory()
{
Aws::Utils::Logging::ShutdownAWSLogging();
Aws::ShutdownAPI(aws_options);
}
ClientFactory & ClientFactory::instance()
{
static ClientFactory ret;
return ret;
}
std::unique_ptr<S3::Client> ClientFactory::create( // NOLINT
const PocoHTTPClientConfiguration & cfg_,
bool is_virtual_hosted_style,
const String & access_key_id,
const String & secret_access_key,
const String & server_side_encryption_customer_key_base64,
HTTPHeaderEntries headers,
bool use_environment_credentials,
bool use_insecure_imds_request)
{
PocoHTTPClientConfiguration client_configuration = cfg_;
client_configuration.updateSchemeAndRegion();
if (!server_side_encryption_customer_key_base64.empty())
{
/// See Client::GeneratePresignedUrlWithSSEC().
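/// SSE-C needs three request headers:
///   x-amz-server-side-encryption-customer-algorithm  -- always AES256 here;
///   x-amz-server-side-encryption-customer-key        -- the base64-encoded key, passed through as-is;
///   x-amz-server-side-encryption-customer-key-MD5    -- base64 of the MD5 of the *decoded* key bytes,
/// which is why the key is base64-decoded before hashing below.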
headers.push_back({Aws::S3::SSEHeaders::SERVER_SIDE_ENCRYPTION_CUSTOMER_ALGORITHM,
Aws::S3::Model::ServerSideEncryptionMapper::GetNameForServerSideEncryption(Aws::S3::Model::ServerSideEncryption::AES256)});
headers.push_back({Aws::S3::SSEHeaders::SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY,
server_side_encryption_customer_key_base64});
Aws::Utils::ByteBuffer buffer = Aws::Utils::HashingUtils::Base64Decode(server_side_encryption_customer_key_base64);
String str_buffer(reinterpret_cast<char *>(buffer.GetUnderlyingData()), buffer.GetLength());
headers.push_back({Aws::S3::SSEHeaders::SERVER_SIDE_ENCRYPTION_CUSTOMER_KEY_MD5,
Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateMD5(str_buffer))});
}
client_configuration.extra_headers = std::move(headers);
Aws::Auth::AWSCredentials credentials(access_key_id, secret_access_key);
auto credentials_provider = std::make_shared<S3CredentialsProviderChain>(
client_configuration,
std::move(credentials),
use_environment_credentials,
use_insecure_imds_request);
client_configuration.retryStrategy = std::make_shared<Client::RetryStrategy>(std::move(client_configuration.retryStrategy));
return Client::create(
client_configuration.s3_max_redirects,
std::move(credentials_provider),
std::move(client_configuration), // Client configuration.
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
is_virtual_hosted_style || client_configuration.endpointOverride.empty() /// Use virtual addressing if endpoint is not specified.
);
}
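/// A hedged illustration of the addressing rule passed to Client::create() above
/// (bucket and region names below are examples, not taken from the configuration):
///   virtual-hosted style: https://mybucket.s3.us-east-1.amazonaws.com/key
///   path-style:           https://s3.us-east-1.amazonaws.com/mybucket/key
/// A custom endpointOverride (e.g. a MinIO deployment) typically requires path-style,
/// which is why virtual addressing is used only when the caller asked for it or no
/// endpoint override was given.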
PocoHTTPClientConfiguration ClientFactory::createClientConfiguration( // NOLINT
const String & force_region,
const RemoteHostFilter & remote_host_filter,
unsigned int s3_max_redirects,
bool enable_s3_requests_logging,
bool for_disk_s3,
const ThrottlerPtr & get_request_throttler,
const ThrottlerPtr & put_request_throttler)
{
return PocoHTTPClientConfiguration(
force_region,
remote_host_filter,
s3_max_redirects,
enable_s3_requests_logging,
for_disk_s3,
get_request_throttler,
put_request_throttler);
}
}
}
#endif
namespace DB


@ -61,43 +61,6 @@ private:
};
}
namespace DB::S3
{
class ClientFactory
{
public:
~ClientFactory();
static ClientFactory & instance();
std::unique_ptr<S3::Client> create(
const PocoHTTPClientConfiguration & cfg,
bool is_virtual_hosted_style,
const String & access_key_id,
const String & secret_access_key,
const String & server_side_encryption_customer_key_base64,
HTTPHeaderEntries headers,
bool use_environment_credentials,
bool use_insecure_imds_request);
PocoHTTPClientConfiguration createClientConfiguration(
const String & force_region,
const RemoteHostFilter & remote_host_filter,
unsigned int s3_max_redirects,
bool enable_s3_requests_logging,
bool for_disk_s3,
const ThrottlerPtr & get_request_throttler,
const ThrottlerPtr & put_request_throttler);
private:
ClientFactory();
Aws::SDKOptions aws_options;
std::atomic<bool> s3_requests_logging_enabled;
};
}
#endif
namespace Poco::Util


@ -448,10 +448,6 @@ InterpreterSelectQuery::InterpreterSelectQuery(
}
}
/// FIXME: Memory bound aggregation may cause another reading algorithm to be used on remote replicas
if (settings.allow_experimental_parallel_reading_from_replicas && settings.enable_memory_bound_merging_of_aggregation_results)
context->setSetting("enable_memory_bound_merging_of_aggregation_results", false);
if (joined_tables.tablesCount() > 1 && settings.allow_experimental_parallel_reading_from_replicas)
{
LOG_WARNING(log, "Joins are not supported with parallel replicas. Query will be executed without using them.");
@ -2520,24 +2516,8 @@ void InterpreterSelectQuery::executeAggregation(QueryPlan & query_plan, const Ac
if (!group_by_info && settings.force_aggregation_in_order)
{
/// Not the most optimal implementation here, but this branch handles a very marginal case.
group_by_sort_description = getSortDescriptionFromGroupBy(getSelectQuery());
auto sorting_step = std::make_unique<SortingStep>(
query_plan.getCurrentDataStream(),
group_by_sort_description,
0 /* LIMIT */,
SortingStep::Settings(*context),
settings.optimize_sorting_by_input_stream_properties);
sorting_step->setStepDescription("Enforced sorting for aggregation in order");
query_plan.addStep(std::move(sorting_step));
group_by_info = std::make_shared<InputOrderInfo>(
group_by_sort_description, group_by_sort_description.size(), 1 /* direction */, 0 /* limit */);
sort_description_for_merging = group_by_info->sort_description_for_merging;
sort_description_for_merging = group_by_sort_description;
}
auto merge_threads = max_streams;
@ -2564,7 +2544,8 @@ void InterpreterSelectQuery::executeAggregation(QueryPlan & query_plan, const Ac
std::move(sort_description_for_merging),
std::move(group_by_sort_description),
should_produce_results_in_order_of_bucket_number,
settings.enable_memory_bound_merging_of_aggregation_results);
settings.enable_memory_bound_merging_of_aggregation_results,
!group_by_info && settings.force_aggregation_in_order);
query_plan.addStep(std::move(aggregating_step));
}


@ -313,7 +313,8 @@ void addAggregationStep(QueryPlan & query_plan,
std::move(sort_description_for_merging),
std::move(group_by_sort_description),
query_analysis_result.aggregation_should_produce_results_in_order_of_bucket_number,
settings.enable_memory_bound_merging_of_aggregation_results);
settings.enable_memory_bound_merging_of_aggregation_results,
settings.force_aggregation_in_order);
query_plan.addStep(std::move(aggregating_step));
}


@ -5,12 +5,14 @@
#include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <IO/Operators.h>
#include <Interpreters/Aggregator.h>
#include <Interpreters/Context.h>
#include <Processors/Merges/AggregatingSortedTransform.h>
#include <Processors/Merges/FinishAggregatingInOrderTransform.h>
#include <Processors/QueryPlan/AggregatingStep.h>
#include <Processors/QueryPlan/IQueryPlanStep.h>
#include <Processors/QueryPlan/SortingStep.h>
#include <Processors/Transforms/AggregatingInOrderTransform.h>
#include <Processors/Transforms/AggregatingTransform.h>
#include <Processors/Transforms/CopyTransform.h>
@ -18,7 +20,6 @@
#include <Processors/Transforms/MemoryBoundMerging.h>
#include <Processors/Transforms/MergingAggregatedMemoryEfficientTransform.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <IO/Operators.h>
#include <Common/JSONBuilder.h>
namespace DB
@ -101,7 +102,8 @@ AggregatingStep::AggregatingStep(
SortDescription sort_description_for_merging_,
SortDescription group_by_sort_description_,
bool should_produce_results_in_order_of_bucket_number_,
bool memory_bound_merging_of_aggregation_results_enabled_)
bool memory_bound_merging_of_aggregation_results_enabled_,
bool explicit_sorting_required_for_aggregation_in_order_)
: ITransformingStep(
input_stream_,
appendGroupingColumn(params_.getHeader(input_stream_.header, final_), params_.keys, !grouping_sets_params_.empty(), group_by_use_nulls_),
@ -120,11 +122,13 @@ AggregatingStep::AggregatingStep(
, group_by_sort_description(std::move(group_by_sort_description_))
, should_produce_results_in_order_of_bucket_number(should_produce_results_in_order_of_bucket_number_)
, memory_bound_merging_of_aggregation_results_enabled(memory_bound_merging_of_aggregation_results_enabled_)
, explicit_sorting_required_for_aggregation_in_order(explicit_sorting_required_for_aggregation_in_order_)
{
if (memoryBoundMergingWillBeUsed())
{
output_stream->sort_description = group_by_sort_description;
output_stream->sort_scope = DataStream::SortScope::Global;
output_stream->has_single_port = true;
}
}
@ -139,6 +143,8 @@ void AggregatingStep::applyOrder(SortDescription sort_description_for_merging_,
output_stream->sort_scope = DataStream::SortScope::Global;
output_stream->has_single_port = true;
}
explicit_sorting_required_for_aggregation_in_order = false;
}
void AggregatingStep::transformPipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings & settings)
@ -333,6 +339,15 @@ void AggregatingStep::transformPipeline(QueryPipelineBuilder & pipeline, const B
if (!sort_description_for_merging.empty())
{
/// We don't rely on input_stream.sort_description here because for now it is not correctly propagated in all cases
/// see https://github.com/ClickHouse/ClickHouse/pull/45892#discussion_r1094503048
if (explicit_sorting_required_for_aggregation_in_order)
{
/// We don't really care about optimality of this sorting, because it's required only in fairly marginal cases.
SortingStep::fullSortStreams(
pipeline, SortingStep::Settings(params.max_block_size), sort_description_for_merging, 0 /* limit */);
}
if (pipeline.getNumStreams() > 1)
{
/** The pipeline is the following:


@ -40,7 +40,8 @@ public:
SortDescription sort_description_for_merging_,
SortDescription group_by_sort_description_,
bool should_produce_results_in_order_of_bucket_number_,
bool memory_bound_merging_of_aggregation_results_enabled_);
bool memory_bound_merging_of_aggregation_results_enabled_,
bool explicit_sorting_required_for_aggregation_in_order_);
static Block appendGroupingColumn(Block block, const Names & keys, bool has_grouping, bool use_nulls);
@ -56,6 +57,7 @@ public:
const Aggregator::Params & getParams() const { return params; }
bool inOrder() const { return !sort_description_for_merging.empty(); }
bool explicitSortingRequired() const { return explicit_sorting_required_for_aggregation_in_order; }
bool isGroupingSets() const { return !grouping_sets_params.empty(); }
void applyOrder(SortDescription sort_description_for_merging_, SortDescription group_by_sort_description_);
bool memoryBoundMergingWillBeUsed() const;
@ -84,6 +86,7 @@ private:
/// These settings are used to determine if we should resize pipeline to 1 at the end.
bool should_produce_results_in_order_of_bucket_number;
bool memory_bound_merging_of_aggregation_results_enabled;
bool explicit_sorting_required_for_aggregation_in_order;
Processors aggregating_in_order;
Processors aggregating_sorted;


@ -73,12 +73,15 @@ MergingAggregatedStep::MergingAggregatedStep(
}
}
void MergingAggregatedStep::updateInputSortDescription(SortDescription sort_description, DataStream::SortScope sort_scope)
void MergingAggregatedStep::applyOrder(SortDescription sort_description, DataStream::SortScope sort_scope)
{
auto & input_stream = input_streams.front();
input_stream.sort_scope = sort_scope;
input_stream.sort_description = sort_description;
/// Columns might be reordered during optimisation, so we'd better update the sort description.
group_by_sort_description = std::move(sort_description);
if (memoryBoundMergingWillBeUsed() && should_produce_results_in_order_of_bucket_number)
{
output_stream->sort_description = group_by_sort_description;


@ -33,7 +33,7 @@ public:
void describeActions(JSONBuilder::JSONMap & map) const override;
void describeActions(FormatSettings & settings) const override;
void updateInputSortDescription(SortDescription input_sort_description, DataStream::SortScope sort_scope);
void applyOrder(SortDescription input_sort_description, DataStream::SortScope sort_scope);
bool memoryBoundMergingWillBeUsed() const;
@ -48,7 +48,7 @@ private:
size_t memory_efficient_merge_threads;
const size_t max_block_size;
const size_t memory_bound_merging_max_block_bytes;
const SortDescription group_by_sort_description;
SortDescription group_by_sort_description;
/// These settings are used to determine if we should resize pipeline to 1 at the end.
const bool should_produce_results_in_order_of_bucket_number;


@ -88,7 +88,7 @@ void enableMemoryBoundMerging(QueryPlan::Node & node, QueryPlan::Nodes &)
reading->enforceAggregationInOrder();
}
root_mergine_aggeregated->updateInputSortDescription(sort_description, DataStream::SortScope::Stream);
root_mergine_aggeregated->applyOrder(sort_description, DataStream::SortScope::Stream);
}
}


@ -1200,7 +1200,7 @@ void optimizeAggregationInOrder(QueryPlan::Node & node, QueryPlan::Nodes &)
if (!aggregating)
return;
if (aggregating->inOrder() || aggregating->isGroupingSets())
if ((aggregating->inOrder() && !aggregating->explicitSortingRequired()) || aggregating->isGroupingSets())
return;
/// TODO: maybe add support for UNION later.


@ -182,7 +182,8 @@ void SortingStep::mergingSorted(QueryPipelineBuilder & pipeline, const SortDescr
}
}
void SortingStep::mergeSorting(QueryPipelineBuilder & pipeline, const SortDescription & result_sort_desc, UInt64 limit_)
void SortingStep::mergeSorting(
QueryPipelineBuilder & pipeline, const Settings & sort_settings, const SortDescription & result_sort_desc, UInt64 limit_)
{
bool increase_sort_description_compile_attempts = true;
@ -200,6 +201,10 @@ void SortingStep::mergeSorting(QueryPipelineBuilder & pipeline, const SortDescri
if (increase_sort_description_compile_attempts)
increase_sort_description_compile_attempts = false;
auto tmp_data_on_disk = sort_settings.tmp_data
? std::make_unique<TemporaryDataOnDisk>(sort_settings.tmp_data, CurrentMetrics::TemporaryFilesForSort)
: std::unique_ptr<TemporaryDataOnDisk>();
return std::make_shared<MergeSortingTransform>(
header,
result_sort_desc,
@ -209,12 +214,17 @@ void SortingStep::mergeSorting(QueryPipelineBuilder & pipeline, const SortDescri
sort_settings.max_bytes_before_remerge / pipeline.getNumStreams(),
sort_settings.remerge_lowered_memory_bytes_ratio,
sort_settings.max_bytes_before_external_sort,
std::make_unique<TemporaryDataOnDisk>(sort_settings.tmp_data, CurrentMetrics::TemporaryFilesForSort),
std::move(tmp_data_on_disk),
sort_settings.min_free_disk_space);
});
}
void SortingStep::fullSort(QueryPipelineBuilder & pipeline, const SortDescription & result_sort_desc, const UInt64 limit_, const bool skip_partial_sort)
void SortingStep::fullSortStreams(
QueryPipelineBuilder & pipeline,
const Settings & sort_settings,
const SortDescription & result_sort_desc,
const UInt64 limit_,
const bool skip_partial_sort)
{
if (!skip_partial_sort || limit_)
{
@ -241,7 +251,13 @@ void SortingStep::fullSort(QueryPipelineBuilder & pipeline, const SortDescriptio
});
}
mergeSorting(pipeline, result_sort_desc, limit_);
mergeSorting(pipeline, sort_settings, result_sort_desc, limit_);
}
void SortingStep::fullSort(
QueryPipelineBuilder & pipeline, const SortDescription & result_sort_desc, const UInt64 limit_, const bool skip_partial_sort)
{
fullSortStreams(pipeline, sort_settings, result_sort_desc, limit_, skip_partial_sort);
/// If there are several streams, then we merge them into one
if (pipeline.getNumStreams() > 1)


@ -73,11 +73,20 @@ public:
Type getType() const { return type; }
const Settings & getSettings() const { return sort_settings; }
static void fullSortStreams(
QueryPipelineBuilder & pipeline,
const Settings & sort_settings,
const SortDescription & result_sort_desc,
UInt64 limit_,
bool skip_partial_sort = false);
private:
void updateOutputStream() override;
static void
mergeSorting(QueryPipelineBuilder & pipeline, const Settings & sort_settings, const SortDescription & result_sort_desc, UInt64 limit_);
void mergingSorted(QueryPipelineBuilder & pipeline, const SortDescription & result_sort_desc, UInt64 limit_);
void mergeSorting(QueryPipelineBuilder & pipeline, const SortDescription & result_sort_desc, UInt64 limit_);
void finishSorting(
QueryPipelineBuilder & pipeline, const SortDescription & input_sort_desc, const SortDescription & result_sort_desc, UInt64 limit_);
void fullSort(


@ -4,11 +4,22 @@
namespace DB
{
IMessageProducer::IMessageProducer(Poco::Logger * log_) : log(log_)
{
}
void AsynchronousMessageProducer::start(const ContextPtr & context)
{
LOG_TEST(log, "Executing startup");
initialize();
producing_task = context->getSchedulePool().createTask(getProducingTaskName(), [this]
{
LOG_TEST(log, "Starting producing task loop");
scheduled.store(true);
scheduled.notify_one();
startProducingTaskLoop();
});
producing_task->activateAndSchedule();
@ -20,8 +31,17 @@ void AsynchronousMessageProducer::finish()
if (finished.exchange(true))
return;
LOG_TEST(log, "Executing shutdown");
/// It is possible that the task with the producer loop hasn't been started yet
/// while the payloads queue is non-empty.
/// If we deactivated it here, the messages would never be sent,
/// as the producer loop would never start.
scheduled.wait(false);
/// Tell the task that it should shut down, but not immediately;
/// it will nevertheless finish executing its current work.
stopProducingTask();
/// Deactivate producing task and wait until it's finished.
/// Wait for the producer task to finish.
producing_task->deactivate();
finishImpl();
}
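/// A minimal standalone sketch of the startup/shutdown handshake above, using the same C++20
/// atomic wait/notify primitives (illustrative names only; requires <atomic>):
namespace producer_handshake_sketch
{
    std::atomic<bool> started{false};

    void taskBody() /// runs on a schedule pool thread
    {
        started.store(true);
        started.notify_one(); /// wake up a shutdown() that arrived before the task ran
        /// ... drain queued payloads until asked to stop ...
    }

    void shutdown()
    {
        started.wait(false); /// blocks only while started == false
        /// now it is safe to request a stop and deactivate the task
    }
}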


@ -6,6 +6,8 @@
#include <Interpreters/Context.h>
#include <Core/BackgroundSchedulePool.h>
namespace Poco { class Logger; }
namespace DB
{
@ -14,6 +16,8 @@ namespace DB
class IMessageProducer
{
public:
explicit IMessageProducer(Poco::Logger * log_);
/// Do some preparations.
virtual void start(const ContextPtr & context) = 0;
@ -24,12 +28,17 @@ public:
virtual void finish() = 0;
virtual ~IMessageProducer() = default;
protected:
Poco::Logger * log;
};
/// Implements interface for concurrent message producing.
class AsynchronousMessageProducer : public IMessageProducer
{
public:
explicit AsynchronousMessageProducer(Poco::Logger * log_) : IMessageProducer(log_) {}
/// Create and schedule task in BackgroundSchedulePool that will produce messages.
void start(const ContextPtr & context) override;
@ -58,6 +67,8 @@ private:
std::atomic<bool> finished = false;
BackgroundSchedulePool::TaskHolder producing_task;
std::atomic<bool> scheduled;
};


@ -18,7 +18,11 @@ namespace DB
KafkaProducer::KafkaProducer(
ProducerPtr producer_, const std::string & topic_, std::chrono::milliseconds poll_timeout, std::atomic<bool> & shutdown_called_, const Block & header)
: producer(producer_), topic(topic_), timeout(poll_timeout), shutdown_called(shutdown_called_)
: IMessageProducer(&Poco::Logger::get("KafkaProducer"))
, producer(producer_)
, topic(topic_)
, timeout(poll_timeout)
, shutdown_called(shutdown_called_)
{
if (header.has("_key"))
{


@ -413,7 +413,8 @@ QueryPlanPtr MergeTreeDataSelectExecutor::read(
std::move(sort_description_for_merging),
std::move(group_by_sort_description),
should_produce_results_in_order_of_bucket_number,
settings.enable_memory_bound_merging_of_aggregation_results);
settings.enable_memory_bound_merging_of_aggregation_results,
!group_by_info && settings.force_aggregation_in_order);
query_plan->addStep(std::move(aggregating_step));
};


@ -497,7 +497,15 @@ static NameSet collectFilesToSkip(
auto source_updated_stream_counts = getStreamCounts(source_part, updated_header.getNames());
auto new_updated_stream_counts = getStreamCounts(new_part, updated_header.getNames());
/// Skip updated files
/// Skip all modified files in new part.
for (const auto & [stream_name, _] : new_updated_stream_counts)
{
files_to_skip.insert(stream_name + ".bin");
files_to_skip.insert(stream_name + mrk_extension);
}
/// Skip files that we read from source part and do not write in new part.
/// E.g. ALTER MODIFY from LowCardinality(String) to String.
for (const auto & [stream_name, _] : source_updated_stream_counts)
{
/// If we read shared stream and do not write it


@ -24,11 +24,11 @@ NATSProducer::NATSProducer(
const String & subject_,
std::atomic<bool> & shutdown_called_,
Poco::Logger * log_)
: connection(configuration_, log_)
: AsynchronousMessageProducer(log_)
, connection(configuration_, log_)
, subject(subject_)
, shutdown_called(shutdown_called_)
, payloads(BATCH)
, log(log_)
{
}


@ -50,8 +50,6 @@ private:
* - payloads are pushed to queue in countRow and popped by another thread in writingFunc, each payload gets into queue only once
*/
ConcurrentBoundedQueue<String> payloads;
Poco::Logger * log;
};
}


@ -31,7 +31,8 @@ RabbitMQProducer::RabbitMQProducer(
const bool persistent_,
std::atomic<bool> & shutdown_called_,
Poco::Logger * log_)
: connection(configuration_, log_)
: AsynchronousMessageProducer(log_)
, connection(configuration_, log_)
, routing_keys(routing_keys_)
, exchange_name(exchange_name_)
, exchange_type(exchange_type_)
@ -40,7 +41,6 @@ RabbitMQProducer::RabbitMQProducer(
, shutdown_called(shutdown_called_)
, payloads(BATCH)
, returned(RETURNED_LIMIT)
, log(log_)
{
}


@ -103,8 +103,6 @@ private:
/// Record of pending acknowledgements from the server; its size never exceeds size of returned.queue
std::map<UInt64, Payload> delivery_record;
Poco::Logger * log;
};
}


@ -1299,13 +1299,15 @@ StorageS3::Configuration StorageS3::getConfiguration(ASTs & engine_args, Context
/// S3('url')
/// S3('url', 'format')
/// S3('url', 'format', 'compression')
/// S3('url', 'aws_access_key_id', 'aws_secret_access_key')
/// S3('url', 'aws_access_key_id', 'aws_secret_access_key', 'format')
/// S3('url', 'aws_access_key_id', 'aws_secret_access_key', 'format', 'compression')
/// with optional headers() function
if (engine_args.empty() || engine_args.size() > 5)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Storage S3 requires 1 to 5 arguments: "
"url, [access_key_id, secret_access_key], name of used format and [compression_method].");
"url, [access_key_id, secret_access_key], name of used format and [compression_method]");
auto * header_it = StorageURL::collectHeaders(engine_args, configuration.headers_from_ast, local_context);
if (header_it != engine_args.end())
@ -1314,24 +1316,49 @@ StorageS3::Configuration StorageS3::getConfiguration(ASTs & engine_args, Context
for (auto & engine_arg : engine_args)
engine_arg = evaluateConstantExpressionOrIdentifierAsLiteral(engine_arg, local_context);
configuration.url = S3::URI(checkAndGetLiteralArgument<String>(engine_args[0], "url"));
if (engine_args.size() >= 4)
/// Size -> argument indexes
static std::unordered_map<size_t, std::unordered_map<std::string_view, size_t>> size_to_engine_args
{
configuration.auth_settings.access_key_id = checkAndGetLiteralArgument<String>(engine_args[1], "access_key_id");
configuration.auth_settings.secret_access_key = checkAndGetLiteralArgument<String>(engine_args[2], "secret_access_key");
{1, {{}}},
{2, {{"format", 1}}},
{4, {{"access_key_id", 1}, {"secret_access_key", 2}, {"format", 3}}},
{5, {{"access_key_id", 1}, {"secret_access_key", 2}, {"format", 3}, {"compression_method", 4}}}
};
std::unordered_map<std::string_view, size_t> engine_args_to_idx;
/// For 3 arguments we support 2 possible variants:
/// s3(source, format, compression_method) and s3(source, access_key_id, secret_access_key)
/// We can distinguish them by looking at the second argument: check whether it's a format name or not.
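/// E.g. S3('https://bucket.s3.amazonaws.com/data.csv', 'CSV', 'gzip') is the first variant,
/// while S3('https://bucket.s3.amazonaws.com/data.csv', 'AKIAEXAMPLE', 'secret') is the second,
/// because 'AKIAEXAMPLE' is not a registered format name (URL and key here are purely illustrative).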
if (engine_args.size() == 3)
{
auto second_arg = checkAndGetLiteralArgument<String>(engine_args[1], "format/access_key_id");
if (second_arg == "auto" || FormatFactory::instance().getAllFormats().contains(second_arg))
engine_args_to_idx = {{"format", 1}, {"compression_method", 2}};
else
engine_args_to_idx = {{"access_key_id", 1}, {"secret_access_key", 2}};
}
else
{
engine_args_to_idx = size_to_engine_args[engine_args.size()];
}
if (engine_args.size() == 3 || engine_args.size() == 5)
{
configuration.compression_method = checkAndGetLiteralArgument<String>(engine_args.back(), "compression_method");
configuration.format = checkAndGetLiteralArgument<String>(engine_args[engine_args.size() - 2], "format");
}
else if (engine_args.size() != 1)
{
configuration.compression_method = "auto";
configuration.format = checkAndGetLiteralArgument<String>(engine_args.back(), "format");
}
/// This argument is always the first
configuration.url = S3::URI(checkAndGetLiteralArgument<String>(engine_args[0], "url"));
if (engine_args_to_idx.contains("format"))
configuration.format = checkAndGetLiteralArgument<String>(engine_args[engine_args_to_idx["format"]], "format");
if (engine_args_to_idx.contains("compression_method"))
configuration.compression_method = checkAndGetLiteralArgument<String>(engine_args[engine_args_to_idx["compression_method"]], "compression_method");
if (engine_args_to_idx.contains("access_key_id"))
configuration.auth_settings.access_key_id = checkAndGetLiteralArgument<String>(engine_args[engine_args_to_idx["access_key_id"]], "access_key_id");
if (engine_args_to_idx.contains("secret_access_key"))
configuration.auth_settings.secret_access_key = checkAndGetLiteralArgument<String>(engine_args[engine_args_to_idx["secret_access_key"]], "secret_access_key");
}
configuration.static_configuration = !configuration.auth_settings.access_key_id.empty();
if (configuration.format == "auto")


@ -48,7 +48,7 @@ void TableFunctionS3::parseArgumentsImpl(const String & error_message, ASTs & ar
arg = evaluateConstantExpressionOrIdentifierAsLiteral(arg, context);
/// Size -> argument indexes
static auto size_to_args = std::map<size_t, std::map<String, size_t>>
static std::unordered_map<size_t, std::unordered_map<std::string_view, size_t>> size_to_args
{
{1, {{}}},
{2, {{"format", 1}}},
@ -56,7 +56,7 @@ void TableFunctionS3::parseArgumentsImpl(const String & error_message, ASTs & ar
{6, {{"access_key_id", 1}, {"secret_access_key", 2}, {"format", 3}, {"structure", 4}, {"compression_method", 5}}}
};
std::map<String, size_t> args_to_idx;
std::unordered_map<std::string_view, size_t> args_to_idx;
/// For 4 arguments we support 2 possible variants:
/// s3(source, format, structure, compression_method) and s3(source, access_key_id, secret_access_key, format)
/// We can distinguish them by looking at the second argument: check whether it's a format name or not.


@ -4,6 +4,7 @@ import argparse
import csv
import logging
import os
import re
import subprocess
import sys
import atexit
@ -112,17 +113,24 @@ def get_run_command(
def get_tests_to_run(pr_info):
result = set([])
result = set()
if pr_info.changed_files is None:
return []
for fpath in pr_info.changed_files:
if "tests/queries/0_stateless/0" in fpath:
logging.info("File %s changed and seems like stateless test", fpath)
if re.match(r"tests/queries/0_stateless/[0-9]{5}", fpath):
logging.info("File '%s' is changed and seems like a test", fpath)
fname = fpath.split("/")[3]
fname_without_ext = os.path.splitext(fname)[0]
# add '.' to the end of the test name so as not to run all tests with the same prefix
# e.g. we changed '00001_some_name.reference'
# and we have ['00001_some_name.sh', '00001_some_name_2.sql'],
# so we want to run only '00001_some_name.sh'
result.add(fname_without_ext + ".")
elif "tests/queries/" in fpath:
# log suspicious changes from tests/ for debugging in case of any problems
logging.info("File '%s' is changed, but it doesn't look like a test", fpath)
return list(result)


@ -76,7 +76,10 @@ def trim_for_log(s):
if not s:
return s
lines = s.splitlines()
return "\n".join(lines[:50] + ["#" * 100] + lines[-50:])
if len(lines) > 100:
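        # keep the first and last 50 lines, with a bar of '#' marking the cut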
return "\n".join(lines[:50] + ["#" * 100] + lines[-50:])
else:
return "\n".join(lines)
class HTTPError(Exception):
@ -939,7 +942,9 @@ class TestCase:
description += "\n"
description += trim_for_log(stderr)
description += "\n"
description += f"\nstdout:\n{stdout}\n"
description += "\nstdout:\n"
description += trim_for_log(stdout)
description += "\n"
if debug_log:
description += "\n"
description += debug_log


@ -760,7 +760,7 @@ class ClickhouseIntegrationTestsRunner:
tests_to_run = get_changed_tests_to_run(pr_info, repo_path)
if not tests_to_run:
logging.info("No tests to run found")
logging.info("No integration tests to run found")
return "success", NO_CHANGES_MSG, [(NO_CHANGES_MSG, "OK")], ""
self._install_clickhouse(build_path)


@ -373,6 +373,7 @@ class ClickHouseCluster:
self.env_file = p.join(self.instances_dir, DEFAULT_ENV_NAME)
self.env_variables = {}
self.env_variables["TSAN_OPTIONS"] = "second_deadlock_stack=1"
self.env_variables["ASAN_OPTIONS"] = "use_sigaltstack=0"
self.env_variables["CLICKHOUSE_WATCHDOG_ENABLE"] = "0"
self.env_variables["CLICKHOUSE_NATS_TLS_SECURE"] = "0"
self.up_called = False


@ -1019,6 +1019,7 @@ def test_rabbitmq_many_inserts(rabbitmq_cluster):
), "ClickHouse lost some messages: {}".format(result)
@pytest.mark.skip(reason="Flaky")
def test_rabbitmq_overloaded_insert(rabbitmq_cluster):
instance.query(
"""


@ -113,22 +113,21 @@ ExpressionTransform
(Expression)
ExpressionTransform × 4
(MergingAggregated)
Resize 1 → 4
SortingAggregatedTransform 4 → 1
MergingAggregatedBucketTransform × 4
Resize 1 → 4
GroupingAggregatedTransform 6 → 1
(Union)
(Aggregating)
MergingAggregatedBucketTransform × 4
Resize 1 → 4
FinishAggregatingInOrderTransform 4 → 1
AggregatingInOrderTransform × 4
(Expression)
ExpressionTransform × 4
(ReadFromMergeTree)
MergeTreeInOrder × 4 0 → 1
(ReadFromRemoteParallelReplicas)
MergingAggregatedBucketTransform × 4
Resize 1 → 4
FinishAggregatingInOrderTransform 3 → 1
(Union)
(Aggregating)
SortingAggregatedForMemoryBoundMergingTransform 4 → 1
MergingAggregatedBucketTransform × 4
Resize 1 → 4
FinishAggregatingInOrderTransform 4 → 1
AggregatingInOrderTransform × 4
(Expression)
ExpressionTransform × 4
(ReadFromMergeTree)
MergeTreeInOrder × 4 0 → 1
(ReadFromRemoteParallelReplicas)
select a, count() from pr_t group by a order by a limit 5 offset 500;
500 1000
501 1000


@ -0,0 +1 @@
450000 450000


@ -0,0 +1,22 @@
DROP TABLE IF EXISTS t_update_empty_nested;
CREATE TABLE t_update_empty_nested
(
`id` UInt32,
`nested.arr1` Array(UInt64)
)
ENGINE = MergeTree
ORDER BY id
SETTINGS min_bytes_for_wide_part = 0;
SET mutations_sync = 2;
INSERT INTO t_update_empty_nested SELECT 1, range(number % 10) FROM numbers(100000);
ALTER TABLE t_update_empty_nested ADD COLUMN `nested.arr2` Array(UInt64);
ALTER TABLE t_update_empty_nested UPDATE `nested.arr2` = `nested.arr1` WHERE 1;
SELECT * FROM t_update_empty_nested FORMAT Null;
SELECT sum(length(nested.arr1)), sum(length(nested.arr2)) FROM t_update_empty_nested;
DROP TABLE t_update_empty_nested;


@ -0,0 +1,19 @@
http://auto.ru/chatay-baranta_bound-in-tanks.ru/forumyazan 2014-03-20 http://auto.ru/chatay-baranta_bound-in-tanks.ru/forumyazan
http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny 2014-03-17 http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny
http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny 2014-03-18 http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny
http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny 2014-03-19 http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny
http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny 2014-03-20 http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny
1
http://auto.ru/chatay-baranta_bound-in-tanks.ru/forumyazan 2014-03-20 http://auto.ru/chatay-baranta_bound-in-tanks.ru/forumyazan
http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny 2014-03-17 http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny
http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny 2014-03-18 http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny
http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny 2014-03-19 http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny
http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny 2014-03-20 http://auto.ru/chatay-baranta_bound-in-thankYou=ru/tver/zhanny
1
MergingAggregatedBucketTransform
FinishAggregatingInOrderTransform
SortingAggregatedForMemoryBoundMergingTransform
MergingAggregatedBucketTransform
FinishAggregatingInOrderTransform
AggregatingInOrderTransform
MergeTreeInOrder


@ -0,0 +1,77 @@
#!/usr/bin/env bash
# shellcheck disable=SC2154
unset CLICKHOUSE_LOG_COMMENT
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
check_replicas_read_in_order() {
# To check this we actually look for at least one log message from MergeTreeInOrderSelectProcessor.
# Hopefully logger names are a bit more stable than the log messages themselves.
$CLICKHOUSE_CLIENT -nq "
SYSTEM FLUSH LOGS;
SELECT COUNT() > 0
FROM system.text_log
WHERE query_id IN (SELECT query_id FROM system.query_log WHERE query_id != '$1' AND initial_query_id = '$1' AND event_date >= yesterday())
AND event_date >= yesterday() AND logger_name = 'MergeTreeInOrderSelectProcessor'"
}
# Replicas should use reading in order, following the initiator's decision to execute aggregation in order.
# At some point we had a bug in this logic (see https://github.com/ClickHouse/ClickHouse/pull/45892#issue-1566140414).
test1() {
query_id="query_id_memory_bound_merging_$RANDOM$RANDOM"
$CLICKHOUSE_CLIENT --query_id="$query_id" -nq "
SET cluster_for_parallel_replicas = 'test_cluster_one_shard_three_replicas_localhost';
SELECT URL, EventDate, max(URL)
FROM remote(test_cluster_one_shard_two_replicas, test.hits)
WHERE CounterID = 1704509 AND UserID = 4322253409885123546
GROUP BY CounterID, URL, EventDate
ORDER BY URL, EventDate
LIMIT 5 OFFSET 10
SETTINGS optimize_aggregation_in_order = 1, enable_memory_bound_merging_of_aggregation_results = 1, allow_experimental_parallel_reading_from_replicas = 1, max_parallel_replicas = 3, use_hedged_requests = 0"
check_replicas_read_in_order $query_id
}
# Replicas should use reading in order, following the initiator's decision to execute aggregation in order.
# At some point we had a bug in this logic (see https://github.com/ClickHouse/ClickHouse/pull/45892#issue-1566140414).
test2() {
query_id="query_id_memory_bound_merging_$RANDOM$RANDOM"
$CLICKHOUSE_CLIENT --query_id="$query_id" -nq "
SET cluster_for_parallel_replicas = 'test_cluster_one_shard_three_replicas_localhost';
SELECT URL, EventDate, max(URL)
FROM remote(test_cluster_one_shard_two_replicas, test.hits)
WHERE CounterID = 1704509 AND UserID = 4322253409885123546
GROUP BY URL, EventDate
ORDER BY URL, EventDate
LIMIT 5 OFFSET 10
SETTINGS optimize_aggregation_in_order = 1, enable_memory_bound_merging_of_aggregation_results = 1, allow_experimental_parallel_reading_from_replicas = 1, max_parallel_replicas = 3, use_hedged_requests = 0, query_plan_aggregation_in_order = 1"
check_replicas_read_in_order $query_id
}
test3() {
$CLICKHOUSE_CLIENT -nq "
SET cluster_for_parallel_replicas = 'test_cluster_one_shard_three_replicas_localhost';
SET max_threads = 16, prefer_localhost_replica = 1, read_in_order_two_level_merge_threshold = 1000, query_plan_aggregation_in_order = 1, distributed_aggregation_memory_efficient = 1;
SELECT replaceRegexpOne(explain, '^ *(\w+).*', '\\1')
FROM (
EXPLAIN PIPELINE
SELECT URL, EventDate, max(URL)
FROM test.hits
WHERE CounterID = 1704509 AND UserID = 4322253409885123546
GROUP BY URL, EventDate
SETTINGS optimize_aggregation_in_order = 1, enable_memory_bound_merging_of_aggregation_results = 1, allow_experimental_parallel_reading_from_replicas = 1, max_parallel_replicas = 3, use_hedged_requests = 0
)
WHERE explain LIKE '%Aggr%Transform%' OR explain LIKE '%InOrder%'"
}
test1
test2
test3