Add default credentials and custom headers for s3 table functions.

Pervakov Grigorii 2020-06-01 20:16:09 +03:00
parent 70a57a84b5
commit bc9381406f
17 changed files with 281 additions and 18 deletions

contrib/aws vendored

@@ -1 +1 @@
-Subproject commit fb5c604525f5151d75a856462653e7e38b559b79
+Subproject commit f7d9ce39f41323300044567be007c233338bb94a

View File

@@ -43,7 +43,10 @@ services:
   # Empty container to run proxy resolver.
   resolver:
-    image: python:3
+    build:
+      context: ../../../docker/test/integration/
+      dockerfile: resolver/Dockerfile
+      network: host
     ports:
       - "4083:8080"
     tty: true

View File

@@ -0,0 +1,4 @@
# Helper docker container to run python bottle apps
FROM python:3
RUN python -m pip install bottle

View File

@@ -4,6 +4,7 @@
# include <IO/S3Common.h>
# include <IO/WriteBufferFromString.h>
# include <Storages/StorageS3Settings.h>
# include <aws/core/auth/AWSCredentialsProvider.h>
# include <aws/core/utils/logging/LogMacros.h>
@@ -60,6 +61,47 @@ public:
private:
    Poco::Logger * log = &Poco::Logger::get("AWSClient");
};

class S3AuthSigner : public Aws::Client::AWSAuthV4Signer
{
public:
    S3AuthSigner(
        const Aws::Client::ClientConfiguration & clientConfiguration,
        const Aws::Auth::AWSCredentials & credentials,
        const DB::HeaderCollection & headers_)
        : Aws::Client::AWSAuthV4Signer(
            std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(credentials),
            "s3",
            clientConfiguration.region,
            Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
            false)
        , headers(headers_)
    {
    }

    bool SignRequest(Aws::Http::HttpRequest & request, const char * region, bool signBody) const override
    {
        auto result = Aws::Client::AWSAuthV4Signer::SignRequest(request, region, signBody);
        for (const auto & header : headers)
            request.SetHeaderValue(header.name, header.value);
        return result;
    }

    bool PresignRequest(
        Aws::Http::HttpRequest & request,
        const char * region,
        const char * serviceName,
        long long expirationTimeInSeconds) const override // NOLINT
    {
        auto result = Aws::Client::AWSAuthV4Signer::PresignRequest(request, region, serviceName, expirationTimeInSeconds);
        for (const auto & header : headers)
            request.SetHeaderValue(header.name, header.value);
        return result;
    }

private:
    const DB::HeaderCollection headers;
};
}
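
The signer above is the heart of the header-injection mechanism: it lets the base AWSAuthV4Signer produce a normal SigV4 signature, then stamps every configured header onto the request, so the custom headers survive signing. A minimal usage sketch (the endpoint, credentials, and header value are placeholders, not values from this commit):

    Aws::Client::ClientConfiguration cfg;
    cfg.endpointOverride = "http://localhost:9000";          // hypothetical endpoint
    Aws::Auth::AWSCredentials credentials("key", "secret");  // placeholder credentials
    DB::HeaderCollection headers{DB::HttpHeader{"Authorization", "Bearer TOKEN"}};
    S3AuthSigner signer(cfg, credentials, headers);
    // Every SignRequest()/PresignRequest() call now appends the configured
    // headers after computing the SigV4 signature; a custom "Authorization"
    // header therefore overwrites the SigV4 one, which is what the
    // integration test at the end of this commit relies on.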
namespace DB
@@ -139,6 +181,25 @@ namespace S3
    );
}

std::shared_ptr<Aws::S3::S3Client> ClientFactory::create( // NOLINT
    const String & endpoint,
    bool is_virtual_hosted_style,
    const String & access_key_id,
    const String & secret_access_key,
    HeaderCollection headers)
{
    Aws::Client::ClientConfiguration cfg;
    if (!endpoint.empty())
        cfg.endpointOverride = endpoint;

    Aws::Auth::AWSCredentials credentials(access_key_id, secret_access_key);
    return std::make_shared<Aws::S3::S3Client>(
        std::make_shared<S3AuthSigner>(cfg, std::move(credentials), std::move(headers)),
        std::move(cfg), // Client configuration.
        is_virtual_hosted_style || cfg.endpointOverride.empty() // Use virtual addressing only if endpoint is not specified.
    );
}
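
For reference, a hedged sketch of calling the new factory overload from client code; all literals are placeholders rather than values taken from this commit:

    // Hypothetical call site: build an S3 client that attaches one extra
    // header to every request.
    DB::HeaderCollection headers;
    headers.emplace_back(DB::HttpHeader{"Authorization", "Bearer TOKEN"});
    auto client = DB::S3::ClientFactory::instance().create(
        "http://localhost:9000",                /// placeholder endpoint override
        /* is_virtual_hosted_style = */ false,
        "minio_access_key",                     /// placeholder credentials
        "minio_secret_key",
        std::move(headers));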
URI::URI(const Poco::URI & uri_)
{
    /// Case when bucket name is represented in the domain name of the S3 URL.

View File

@@ -5,7 +5,7 @@
#if USE_AWS_S3
#include <Core/Types.h>
#include <Poco/URI.h>
#include <Interpreters/Context.h>
#include <aws/core/Aws.h>
namespace Aws::S3
@@ -13,6 +13,12 @@ namespace Aws::S3
class S3Client;
}
namespace DB
{
struct HttpHeader;
using HeaderCollection = std::vector<HttpHeader>;
}
namespace DB::S3
{
@@ -34,6 +40,14 @@ public:
        bool is_virtual_hosted_style,
        const String & access_key_id,
        const String & secret_access_key);

    std::shared_ptr<Aws::S3::S3Client> create(
        const String & endpoint,
        bool is_virtual_hosted_style,
        const String & access_key_id,
        const String & secret_access_key,
        HeaderCollection headers);

private:
    ClientFactory();

View File

@@ -22,6 +22,7 @@
#include <Storages/MergeTree/MergeList.h>
#include <Storages/MergeTree/MergeTreeSettings.h>
#include <Storages/CompressionCodecSelector.h>
#include <Storages/StorageS3Settings.h>
#include <Disks/DiskLocal.h>
#include <TableFunctions/TableFunctionFactory.h>
#include <Interpreters/ActionLocksManager.h>
@@ -351,6 +352,7 @@ struct ContextShared
    String format_schema_path;                              /// Path to a directory that contains schema files used by input formats.
    ActionLocksManagerPtr action_locks_manager;             /// Set of storages' action lockers
    std::optional<SystemLogs> system_logs;                  /// Used to log queries and operations on parts
    std::optional<StorageS3Settings> storage_s3_settings;   /// Settings of S3 storage
    RemoteHostFilter remote_host_filter;                    /// Allowed URL from config.xml
@@ -1764,6 +1766,11 @@ void Context::updateStorageConfiguration(const Poco::Util::AbstractConfiguration
            LOG_ERROR(shared->log, "An error has occurred while reloading storage policies, storage policies were not applied: {}", e.message());
        }
    }

    if (shared->storage_s3_settings)
    {
        shared->storage_s3_settings->loadFromConfig("s3", config);
    }
}
@@ -1782,6 +1789,18 @@ const MergeTreeSettings & Context::getMergeTreeSettings() const
    return *shared->merge_tree_settings;
}

const StorageS3Settings & Context::getStorageS3Settings() const
{
    auto lock = getLock();
    if (!shared->storage_s3_settings)
    {
        const auto & config = getConfigRef();
        shared->storage_s3_settings.emplace().loadFromConfig("s3", config);
    }

    return *shared->storage_s3_settings;
}
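
A brief sketch of the intended caller side (context and uri are hypothetical locals standing in for the real members; the actual use is in the StorageS3 constructor below):

    // Per-endpoint auth settings, lazily parsed from the <s3> config section.
    DB::S3AuthSettings auth = context.getStorageS3Settings().getSettings(uri.endpoint);
    // On a miss, getSettings() returns empty strings and no headers, so callers
    // fall back to these values only when the query supplied no credentials.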
void Context::checkCanBeDropped(const String & database, const String & table, const size_t & size, const size_t & max_size_to_drop) const
{

View File

@@ -81,6 +81,7 @@ class TextLog;
class TraceLog;
class MetricLog;
struct MergeTreeSettings;
class StorageS3Settings;
class IDatabase;
class DDLWorker;
class ITableFunction;
@@ -531,6 +532,7 @@ public:
    std::shared_ptr<PartLog> getPartLog(const String & part_database);

    const MergeTreeSettings & getMergeTreeSettings() const;
    const StorageS3Settings & getStorageS3Settings() const;

    /// Prevents DROP TABLE if its size is greater than max_size (50GB by default, max_size=0 turn off this check)
    void setMaxTableSizeToDrop(size_t max_size);

View File

@@ -5,6 +5,7 @@
#include <IO/S3Common.h>
#include <Storages/StorageFactory.h>
#include <Storages/StorageS3.h>
#include <Storages/StorageS3Settings.h>
#include <Interpreters/Context.h>
#include <Interpreters/evaluateConstantExpression.h>
@@ -23,6 +24,7 @@
#include <DataTypes/DataTypeString.h>
#include <aws/core/auth/AWSCredentials.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/model/ListObjectsV2Request.h>
@@ -200,18 +202,24 @@ StorageS3::StorageS3(
     , format_name(format_name_)
     , min_upload_part_size(min_upload_part_size_)
     , compression_method(compression_method_)
-    , client(S3::ClientFactory::instance().create(uri_.endpoint, uri_.is_virtual_hosted_style, access_key_id_, secret_access_key_))
 {
     context_global.getRemoteHostFilter().checkURL(uri_.uri);
     setColumns(columns_);
     setConstraints(constraints_);
+
+    auto settings = context_.getStorageS3Settings().getSettings(uri.endpoint);
+
+    Aws::Auth::AWSCredentials credentials(access_key_id_, secret_access_key_);
+    if (access_key_id_.empty())
+        credentials = Aws::Auth::AWSCredentials(std::move(settings.access_key_id), std::move(settings.secret_access_key));
+
+    /// Pass the resolved credentials (explicit ones, or the per-endpoint defaults from config).
+    client = S3::ClientFactory::instance().create(
+        uri_.endpoint, uri_.is_virtual_hosted_style, credentials.GetAWSAccessKeyId(), credentials.GetAWSSecretKey(), std::move(settings.headers));
 }
namespace
{
/* "Recursive" directory listing with matched paths as a result.
/* "Recursive" directory listing with matched paths as a result.
* Have the same method in StorageFile.
*/
Strings listFilesWithRegexpMatching(Aws::S3::S3Client & client, const S3::URI & globbed_uri)

View File

@@ -0,0 +1,57 @@
#include <Storages/StorageS3Settings.h>

#include <Poco/Util/AbstractConfiguration.h>
#include <Common/Exception.h>

namespace DB
{

namespace ErrorCodes
{
    extern const int INVALID_CONFIG_PARAMETER;
}

void StorageS3Settings::loadFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config)
{
    auto lock = std::unique_lock(mutex);
    settings.clear();

    if (!config.has(config_elem))
        return;

    Poco::Util::AbstractConfiguration::Keys config_keys;
    config.keys(config_elem, config_keys);

    for (const String & key : config_keys)
    {
        auto endpoint = config.getString(config_elem + "." + key + ".endpoint");
        auto access_key_id = config.getString(config_elem + "." + key + ".access_key_id", "");
        auto secret_access_key = config.getString(config_elem + "." + key + ".secret_access_key", "");

        HeaderCollection headers;
        Poco::Util::AbstractConfiguration::Keys subconfig_keys;
        config.keys(config_elem + "." + key, subconfig_keys);
        for (const String & subkey : subconfig_keys)
        {
            if (subkey.starts_with("header"))
            {
                auto header_str = config.getString(config_elem + "." + key + "." + subkey);
                auto delimiter = header_str.find(':');
                if (delimiter == String::npos)
                    throw Exception("Malformed s3 header value", ErrorCodes::INVALID_CONFIG_PARAMETER);
                headers.emplace_back(HttpHeader{header_str.substr(0, delimiter), header_str.substr(delimiter + 1, String::npos)});
            }
        }

        settings.emplace(endpoint, S3AuthSettings{std::move(access_key_id), std::move(secret_access_key), std::move(headers)});
    }
}

S3AuthSettings StorageS3Settings::getSettings(const String & endpoint) const
{
    auto lock = std::unique_lock(mutex);
    if (auto setting = settings.find(endpoint); setting != settings.end())
        return setting->second;
    return {};
}

}
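
To make the parsing rules concrete, a self-contained sketch that feeds an in-memory config through loadFromConfig; Poco::Util::XMLConfiguration stands in for the server config here, and the XML mirrors the integration-test file configs/defaultS3.xml further below:

    #include <sstream>
    #include <Poco/AutoPtr.h>
    #include <Poco/Util/XMLConfiguration.h>
    #include <Storages/StorageS3Settings.h>

    void example()
    {
        // Minimal <s3> section: one named endpoint with a custom header.
        std::istringstream xml(
            "<yandex><s3><s3_mock>"
            "<endpoint>http://resolver:8080</endpoint>"
            "<header>Authorization: Bearer TOKEN</header>"
            "</s3_mock></s3></yandex>");
        Poco::AutoPtr<Poco::Util::XMLConfiguration> config(new Poco::Util::XMLConfiguration(xml));

        DB::StorageS3Settings settings;
        settings.loadFromConfig("s3", *config);  /// the root element name is not part of the key

        /// Exact-match lookup by endpoint; an unknown endpoint yields empty settings.
        auto auth = settings.getSettings("http://resolver:8080");
        /// auth.headers holds one entry: name "Authorization", value " Bearer TOKEN"
        /// (the split keeps everything after the first ':', including the space).
    }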

View File

@@ -0,0 +1,46 @@
#pragma once

#include <map>
#include <memory>
#include <mutex>
#include <vector>
#include <Core/Types.h>

namespace Poco::Util
{
class AbstractConfiguration;
}

namespace DB
{

struct HttpHeader
{
    const String name;
    const String value;
};

using HeaderCollection = std::vector<HttpHeader>;

struct S3AuthSettings
{
    const String access_key_id;
    const String secret_access_key;

    const HeaderCollection headers;
};

/// Settings for the StorageS3.
class StorageS3Settings
{
public:
    StorageS3Settings() = default;

    void loadFromConfig(const String & config_elem, const Poco::Util::AbstractConfiguration & config);

    S3AuthSettings getSettings(const String & endpoint) const;

private:
    mutable std::mutex mutex;
    std::map<const String, const S3AuthSettings> settings;
};

}

View File

@@ -164,6 +164,7 @@ SRCS(
StorageMySQL.cpp
StorageNull.cpp
StorageReplicatedMergeTree.cpp
StorageS3Settings.cpp
StorageSet.cpp
StorageStripeLog.cpp
StorageTinyLog.cpp

View File

@@ -1,4 +0,0 @@
-#!/bin/bash
-pip install bottle
-python resolver.py

View File

@@ -14,9 +14,7 @@ def run_resolver(cluster):
     current_dir = os.path.dirname(__file__)
     cluster.copy_file_to_container(container_id, os.path.join(current_dir, "proxy-resolver", "resolver.py"),
                                    "resolver.py")
-    cluster.copy_file_to_container(container_id, os.path.join(current_dir, "proxy-resolver", "entrypoint.sh"),
-                                   "entrypoint.sh")
-    cluster.exec_in_container(container_id, ["/bin/bash", "entrypoint.sh"], detach=True)
+    cluster.exec_in_container(container_id, ["python", "resolver.py"], detach=True)
@pytest.fixture(scope="module")

View File

@@ -0,0 +1,8 @@
<yandex>
    <s3>
        <s3_mock>
            <endpoint>http://resolver:8080</endpoint>
            <header>Authorization: Bearer TOKEN</header>
        </s3_mock>
    </s3>
</yandex>

View File

@@ -0,0 +1,17 @@
from bottle import abort, route, run, request


@route('/<_bucket>/<_path>')
def server(_bucket, _path):
    for name in request.headers:
        if name == 'Authorization' and request.headers[name] == u'Bearer TOKEN':
            return '1, 2, 3'
    abort(403)


@route('/')
def ping():
    return 'OK'


run(host='0.0.0.0', port=8080)

View File

@@ -2,6 +2,7 @@ import json
import logging
import random
import threading
import os
import pytest
@@ -9,7 +10,6 @@ from helpers.cluster import ClickHouseCluster, ClickHouseInstance
import helpers.client
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler())
@@ -82,14 +82,16 @@ def get_nginx_access_logs():
 def cluster():
     try:
         cluster = ClickHouseCluster(__file__)
-        cluster.add_instance("restricted_dummy", main_configs=["configs/config_for_test_remote_host_filter.xml"], with_minio=True)
-        cluster.add_instance("dummy", with_minio=True)
+        cluster.add_instance("restricted_dummy", main_configs=["configs/config_for_test_remote_host_filter.xml"],
+                             with_minio=True)
+        cluster.add_instance("dummy", with_minio=True, main_configs=["configs/defaultS3.xml"])
         logging.info("Starting cluster...")
         cluster.start()
         logging.info("Cluster started")

         prepare_s3_bucket(cluster)
         logging.info("S3 bucket created")
+        run_s3_mock(cluster)

         yield cluster
     finally:
@@ -199,14 +201,15 @@ def test_put_get_with_globs(cluster):
         for j in range(10):
             path = "{}_{}/{}.csv".format(i, random.choice(['a', 'b', 'c', 'd']), j)
             max_path = max(path, max_path)
-            values = "({},{},{})".format(i, j, i+j)
+            values = "({},{},{})".format(i, j, i + j)
             query = "insert into table function s3('http://{}:{}/{}/{}', 'CSV', '{}') values {}".format(
                 cluster.minio_host, cluster.minio_port, bucket, path, table_format, values)
             run_query(instance, query)

     query = "select sum(column1), sum(column2), sum(column3), min(_file), max(_path) from s3('http://{}:{}/{}/*_{{a,b,c,d}}/%3f.csv', 'CSV', '{}')".format(
         cluster.minio_redirect_host, cluster.minio_redirect_port, bucket, table_format)
-    assert run_query(instance, query).splitlines() == ["450\t450\t900\t0.csv\t{bucket}/{max_path}".format(bucket=bucket, max_path=max_path)]
+    assert run_query(instance, query).splitlines() == [
+        "450\t450\t900\t0.csv\t{bucket}/{max_path}".format(bucket=bucket, max_path=max_path)]
# Test multipart put.
@@ -307,3 +310,29 @@ def test_s3_glob_scheherazade(cluster):
    query = "select count(), sum(column1), sum(column2), sum(column3) from s3('http://{}:{}/{}/night_*/tale.csv', 'CSV', '{}')".format(
        cluster.minio_redirect_host, cluster.minio_redirect_port, bucket, table_format)
    assert run_query(instance, query).splitlines() == ["1001\t1001\t1001\t1001"]


def run_s3_mock(cluster):
    logging.info("Starting s3 mock")
    container_id = cluster.get_container_id('resolver')
    current_dir = os.path.dirname(__file__)
    cluster.copy_file_to_container(container_id, os.path.join(current_dir, "s3_mock", "mock_s3.py"), "mock_s3.py")
    cluster.exec_in_container(container_id, ["python", "mock_s3.py"], detach=True)
    logging.info("S3 mock started")


# Test get values in CSV format with default settings.
def test_get_csv_default(cluster):
    ping_response = cluster.exec_in_container(cluster.get_container_id('resolver'), ["curl", "-s", "http://resolver:8080"])
    assert ping_response == 'OK', 'Expected "OK", but got "{}"'.format(ping_response)

    table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
    filename = "test.csv"
    get_query = "select * from s3('http://resolver:8080/{bucket}/{file}', 'CSV', '{table_format}')".format(
        bucket=cluster.minio_restricted_bucket,
        file=filename,
        table_format=table_format)

    instance = cluster.instances["dummy"]  # type: ClickHouseInstance
    result = run_query(instance, get_query)
    assert result == '1\t2\t3\n'