Merge pull request #21457 from GrigoryPervakov/master

recreate S3 client if credentials changed
This commit is contained in:
Vladimir 2021-03-11 15:18:39 +03:00 committed by GitHub
commit 08f312b1c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 96 additions and 35 deletions

View File

@ -203,6 +203,9 @@ StorageS3::StorageS3(
const String & compression_method_)
: IStorage(table_id_)
, uri(uri_)
, access_key_id(access_key_id_)
, secret_access_key(secret_access_key_)
, max_connections(max_connections_)
, global_context(context_.getGlobalContext())
, format_name(format_name_)
, min_upload_part_size(min_upload_part_size_)
@ -215,29 +218,7 @@ StorageS3::StorageS3(
storage_metadata.setColumns(columns_);
storage_metadata.setConstraints(constraints_);
setInMemoryMetadata(storage_metadata);
auto settings = context_.getStorageS3Settings().getSettings(uri.uri.toString());
Aws::Auth::AWSCredentials credentials(access_key_id_, secret_access_key_);
if (access_key_id_.empty())
credentials = Aws::Auth::AWSCredentials(std::move(settings.access_key_id), std::move(settings.secret_access_key));
S3::PocoHTTPClientConfiguration client_configuration = S3::ClientFactory::instance().createClientConfiguration(
context_.getRemoteHostFilter(),
context_.getGlobalContext().getSettingsRef().s3_max_redirects);
client_configuration.endpointOverride = uri_.endpoint;
client_configuration.maxConnections = max_connections_;
client = S3::ClientFactory::instance().create(
client_configuration,
uri_.is_virtual_hosted_style,
credentials.GetAWSAccessKeyId(),
credentials.GetAWSSecretKey(),
settings.server_side_encryption_customer_key_base64,
std::move(settings.headers),
settings.use_environment_credentials.value_or(global_context.getConfigRef().getBool("s3.use_environment_credentials", false))
);
updateAuthSettings(context_);
}
@ -309,6 +290,8 @@ Pipe StorageS3::read(
size_t max_block_size,
unsigned num_streams)
{
updateAuthSettings(context);
Pipes pipes;
bool need_path_column = false;
bool need_file_column = false;
@ -342,8 +325,9 @@ Pipe StorageS3::read(
return pipe;
}
BlockOutputStreamPtr StorageS3::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, const Context & /*context*/)
BlockOutputStreamPtr StorageS3::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, const Context & context)
{
updateAuthSettings(context);
return std::make_shared<StorageS3BlockOutputStream>(
format_name,
metadata_snapshot->getSampleBlock(),
@ -356,6 +340,38 @@ BlockOutputStreamPtr StorageS3::write(const ASTPtr & /*query*/, const StorageMet
max_single_part_upload_size);
}
void StorageS3::updateAuthSettings(const Context & context)
{
auto settings = context.getStorageS3Settings().getSettings(uri.uri.toString());
if (client && (!access_key_id.empty() || settings == auth_settings))
return;
Aws::Auth::AWSCredentials credentials(access_key_id, secret_access_key);
HeaderCollection headers;
if (access_key_id.empty())
{
credentials = Aws::Auth::AWSCredentials(settings.access_key_id, settings.secret_access_key);
headers = settings.headers;
}
S3::PocoHTTPClientConfiguration client_configuration = S3::ClientFactory::instance().createClientConfiguration(
context.getRemoteHostFilter(), context.getGlobalContext().getSettingsRef().s3_max_redirects);
client_configuration.endpointOverride = uri.endpoint;
client_configuration.maxConnections = max_connections;
client = S3::ClientFactory::instance().create(
client_configuration,
uri.is_virtual_hosted_style,
credentials.GetAWSAccessKeyId(),
credentials.GetAWSSecretKey(),
settings.server_side_encryption_customer_key_base64,
std::move(headers),
settings.use_environment_credentials.value_or(global_context.getConfigRef().getBool("s3.use_environment_credentials", false)));
auth_settings = std::move(settings);
}
void registerStorageS3Impl(const String & name, StorageFactory & factory)
{
factory.registerStorage(name, [](const StorageFactory::Arguments & args)

View File

@ -5,6 +5,7 @@
#if USE_AWS_S3
#include <Storages/IStorage.h>
#include <Storages/StorageS3Settings.h>
#include <Poco/URI.h>
#include <common/logger_useful.h>
#include <ext/shared_ptr_helper.h>
@ -57,7 +58,10 @@ public:
NamesAndTypesList getVirtuals() const override;
private:
S3::URI uri;
const S3::URI uri;
const String access_key_id;
const String secret_access_key;
const UInt64 max_connections;
const Context & global_context;
String format_name;
@ -66,6 +70,9 @@ private:
String compression_method;
std::shared_ptr<Aws::S3::S3Client> client;
String name;
S3AuthSettings auth_settings;
void updateAuthSettings(const Context & context);
};
}

View File

@ -14,24 +14,32 @@ class AbstractConfiguration;
namespace DB
{
struct HttpHeader
{
const String name;
const String value;
String name;
String value;
inline bool operator==(const HttpHeader & other) const { return name == other.name && value == other.value; }
};
using HeaderCollection = std::vector<HttpHeader>;
struct S3AuthSettings
{
const String access_key_id;
const String secret_access_key;
const String server_side_encryption_customer_key_base64;
String access_key_id;
String secret_access_key;
String server_side_encryption_customer_key_base64;
const HeaderCollection headers;
HeaderCollection headers;
std::optional<bool> use_environment_credentials;
inline bool operator==(const S3AuthSettings & other) const
{
return access_key_id == other.access_key_id && secret_access_key == other.secret_access_key
&& server_side_encryption_customer_key_base64 == other.server_side_encryption_customer_key_base64 && headers == other.headers
&& use_environment_credentials == other.use_environment_credentials;
}
};
/// Settings for the StorageS3.

View File

@ -14,6 +14,9 @@ from helpers.cluster import ClickHouseCluster, ClickHouseInstance
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler())
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
CONFIG_PATH = os.path.join(SCRIPT_DIR, './_instances/dummy/configs/config.d/defaultS3.xml')
# Creates S3 bucket for tests and allows anonymous read-write access to it.
def prepare_s3_bucket(cluster):
@ -85,7 +88,8 @@ def cluster():
cluster.add_instance("restricted_dummy", main_configs=["configs/config_for_test_remote_host_filter.xml"],
with_minio=True)
cluster.add_instance("dummy", with_minio=True, main_configs=["configs/defaultS3.xml"])
cluster.add_instance("s3_max_redirects", with_minio=True, main_configs=["configs/defaultS3.xml"], user_configs=["configs/s3_max_redirects.xml"])
cluster.add_instance("s3_max_redirects", with_minio=True, main_configs=["configs/defaultS3.xml"],
user_configs=["configs/s3_max_redirects.xml"])
logging.info("Starting cluster...")
cluster.start()
logging.info("Cluster started")
@ -277,9 +281,9 @@ def test_put_get_with_globs(cluster):
# Test multipart put.
@pytest.mark.parametrize("maybe_auth,positive", [
("", True),
("", True)
# ("'minio','minio123',",True), Redirect with credentials not working with nginx.
("'wrongid','wrongkey',", False)
# ("'wrongid','wrongkey',", False) ClickHouse crashes in some time after this test, local integration tests run fails.
])
def test_multipart_put(cluster, maybe_auth, positive):
# type: (ClickHouseCluster) -> None
@ -397,6 +401,16 @@ def run_s3_mock(cluster):
logging.info("S3 mock started")
def replace_config(old, new):
config = open(CONFIG_PATH, 'r')
config_lines = config.readlines()
config.close()
config_lines = [line.replace(old, new) for line in config_lines]
config = open(CONFIG_PATH, 'w')
config.writelines(config_lines)
config.close()
def test_custom_auth_headers(cluster):
table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
filename = "test.csv"
@ -409,6 +423,22 @@ def test_custom_auth_headers(cluster):
result = run_query(instance, get_query)
assert result == '1\t2\t3\n'
instance.query(
"CREATE TABLE test ({table_format}) ENGINE = S3('http://resolver:8080/{bucket}/{file}', 'CSV')".format(
bucket=cluster.minio_restricted_bucket,
file=filename,
table_format=table_format
))
assert run_query(instance, "SELECT * FROM test") == '1\t2\t3\n'
replace_config("<header>Authorization: Bearer TOKEN", "<header>Authorization: Bearer INVALID_TOKEN")
instance.query("SYSTEM RELOAD CONFIG")
ret, err = instance.query_and_get_answer_with_error("SELECT * FROM test")
assert ret == "" and err != ""
replace_config("<header>Authorization: Bearer INVALID_TOKEN", "<header>Authorization: Bearer TOKEN")
instance.query("SYSTEM RELOAD CONFIG")
assert run_query(instance, "SELECT * FROM test") == '1\t2\t3\n'
def test_custom_auth_headers_exclusion(cluster):
table_format = "column1 UInt32, column2 UInt32, column3 UInt32"