Merge pull request #7992 from ClickHouse/excitoon-favorites-s3auth

Merging #7623
alexey-milovidov 2019-12-03 15:21:12 +03:00 committed by GitHub
commit f09c29a2b6
16 changed files with 385 additions and 148 deletions

View File

@@ -464,6 +464,7 @@ namespace ErrorCodes
extern const int CANNOT_GET_CREATE_DICTIONARY_QUERY = 487;
extern const int UNKNOWN_DICTIONARY = 488;
extern const int INCORRECT_DICTIONARY_DEFINITION = 489;
extern const int CANNOT_FORMAT_DATETIME = 490;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;

View File

@@ -91,19 +91,7 @@ private:
template <typename T>
static inline void writeNumber2(char * p, T v)
{
static const char digits[201] =
"00010203040506070809"
"10111213141516171819"
"20212223242526272829"
"30313233343536373839"
"40414243444546474849"
"50515253545556575859"
"60616263646566676869"
"70717273747576777879"
"80818283848586878889"
"90919293949596979899";
memcpy(p, &digits[v * 2], 2);
memcpy(p, &digits100[v * 2], 2);
}
template <typename T>
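The private digits table inside writeNumber2 is deleted; the function now indexes the shared digits100 table added to WriteHelpers.h further down. The trick, as a standalone sketch (hypothetical names, not the ClickHouse code): every two-digit pair "00".."99" is packed into one 200-character string, so any value 0..99 becomes two characters with a single memcpy and no division.

#include <cstring>
#include <cstdio>

// The pair for v starts at index v * 2; 201 bytes including the terminator.
static const char digits100[201] =
    "00010203040506070809"
    "10111213141516171819"
    "20212223242526272829"
    "30313233343536373839"
    "40414243444546474849"
    "50515253545556575859"
    "60616263646566676869"
    "70717273747576777879"
    "80818283848586878889"
    "90919293949596979899";

static void write_number2(char * p, unsigned v)  // expects 0 <= v <= 99
{
    std::memcpy(p, &digits100[v * 2], 2);
}

int main()
{
    char buf[6] = "00:00";
    write_number2(buf, 7);       // hours
    write_number2(buf + 3, 42);  // minutes
    std::puts(buf);              // prints "07:42"
}

Sharing one table avoids duplicating the same 201-byte constant in every formatting function.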

View File

@@ -1,6 +1,7 @@
#include <IO/ReadBufferFromS3.h>
#include <IO/ReadBufferFromIStream.h>
#include <IO/S3Common.h>
#include <common/logger_useful.h>
@@ -10,13 +11,12 @@ namespace DB
const int DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT = 2;
ReadBufferFromS3::ReadBufferFromS3(Poco::URI uri_,
const ConnectionTimeouts & timeouts,
const Poco::Net::HTTPBasicCredentials & credentials,
size_t buffer_size_)
ReadBufferFromS3::ReadBufferFromS3(const Poco::URI & uri_,
const String & access_key_id_,
const String & secret_access_key_,
const ConnectionTimeouts & timeouts)
: ReadBuffer(nullptr, 0)
, uri {uri_}
, method {Poco::Net::HTTPRequest::HTTP_GET}
, session {makeHTTPSession(uri_, timeouts)}
{
Poco::Net::HTTPResponse response;
@@ -28,11 +28,13 @@ ReadBufferFromS3::ReadBufferFromS3(Poco::URI uri_,
if (uri.getPath().empty())
uri.setPath("/");
request = std::make_unique<Poco::Net::HTTPRequest>(method, uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request = std::make_unique<Poco::Net::HTTPRequest>(
Poco::Net::HTTPRequest::HTTP_GET,
uri.getPathAndQuery(),
Poco::Net::HTTPRequest::HTTP_1_1);
request->setHost(uri.getHost()); // use original, not resolved host name in header
if (!credentials.getUsername().empty())
credentials.authenticate(*request);
S3Helper::authenticateRequest(*request, access_key_id_, secret_access_key_);
LOG_TRACE((&Logger::get("ReadBufferFromS3")), "Sending request to " << uri.toString());
@@ -54,7 +56,7 @@ ReadBufferFromS3::ReadBufferFromS3(Poco::URI uri_,
}
assertResponseIsOk(*request, response, *istr);
impl = std::make_unique<ReadBufferFromIStream>(*istr, buffer_size_);
impl = std::make_unique<ReadBufferFromIStream>(*istr, DBMS_DEFAULT_BUFFER_SIZE);
}
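Two changes here: the Poco::Net::HTTPBasicCredentials parameter gives way to an access-key pair signed through the new S3Helper, and the buffer size is now fixed at DBMS_DEFAULT_BUFFER_SIZE rather than passed in. For orientation, the plain HTTP flow the constructor wraps looks roughly like this (a minimal sketch with a made-up endpoint, not the actual buffer implementation):

#include <iostream>
#include <Poco/Net/HTTPClientSession.h>
#include <Poco/Net/HTTPRequest.h>
#include <Poco/Net/HTTPResponse.h>
#include <Poco/StreamCopier.h>
#include <Poco/URI.h>

int main()
{
    Poco::URI uri("http://127.0.0.1:9001/bucket/test.csv");  // hypothetical endpoint
    if (uri.getPath().empty())
        uri.setPath("/");

    Poco::Net::HTTPClientSession session(uri.getHost(), uri.getPort());
    Poco::Net::HTTPRequest request(
        Poco::Net::HTTPRequest::HTTP_GET, uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
    request.setHost(uri.getHost());  // original, not resolved, host name, as in the patch

    session.sendRequest(request);
    Poco::Net::HTTPResponse response;
    std::istream & body = session.receiveResponse(response);

    // The real constructor additionally signs the request, retries on 30x responses up to
    // DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT times, and wraps `body` in a ReadBufferFromIStream.
    Poco::StreamCopier::copyStream(body, std::cout);
}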

View File

@@ -17,17 +17,15 @@ class ReadBufferFromS3 : public ReadBuffer
{
protected:
Poco::URI uri;
std::string method;
HTTPSessionPtr session;
std::istream * istr; /// owned by session
std::unique_ptr<ReadBuffer> impl;
public:
explicit ReadBufferFromS3(Poco::URI uri_,
const ConnectionTimeouts & timeouts = {},
const Poco::Net::HTTPBasicCredentials & credentials = {},
size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE);
explicit ReadBufferFromS3(const Poco::URI & uri_,
const String & access_key_id_,
const String & secret_access_key_,
const ConnectionTimeouts & timeouts = {});
bool nextImpl() override;
};

dbms/src/IO/S3Common.cpp (new file, 60 lines)
View File

@@ -0,0 +1,60 @@
#include <IO/S3Common.h>
#include <IO/WriteHelpers.h>
#include <IO/WriteBufferFromString.h>
#include <iterator>
#include <sstream>
#include <Poco/Base64Encoder.h>
#include <Poco/HMACEngine.h>
#include <Poco/SHA1Engine.h>
#include <Poco/URI.h>
namespace DB
{
namespace ErrorCodes
{
extern const int CANNOT_FORMAT_DATETIME;
}
void S3Helper::authenticateRequest(
Poco::Net::HTTPRequest & request,
const String & access_key_id,
const String & secret_access_key)
{
/// See https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html
if (access_key_id.empty())
return;
/// Limitations:
/// 1. Virtual hosted-style requests are not supported (e.g. `http://johnsmith.net.s3.amazonaws.com/homepage.html`).
/// 2. AMZ headers are not supported (TODO).
if (!request.has("Date"))
{
WriteBufferFromOwnString out;
writeDateTimeTextRFC1123(time(nullptr), out, DateLUT::instance("UTC"));
request.set("Date", out.str());
}
String string_to_sign = request.getMethod() + "\n"
+ request.get("Content-MD5", "") + "\n"
+ request.get("Content-Type", "") + "\n"
+ request.get("Date") + "\n"
+ Poco::URI(request.getURI()).getPathAndQuery();
Poco::HMACEngine<Poco::SHA1Engine> engine(secret_access_key);
engine.update(string_to_sign);
auto digest = engine.digest();
std::ostringstream signature;
Poco::Base64Encoder encoder(signature);
std::copy(digest.begin(), digest.end(), std::ostream_iterator<char>(encoder));
encoder.close();
request.set("Authorization", "AWS " + access_key_id + ":" + signature.str());
}
}
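This is AWS Signature Version 2, per the linked page: the Authorization header is `AWS <access_key_id>:Base64(HMAC-SHA1(secret_access_key, StringToSign))`. As a worked illustration (all values hypothetical), a GET for /mybucket/test.csv with no body, Content-MD5 or Content-Type, dated Tue, 03 Dec 2019 12:00:00 GMT, signs the string

GET\n
\n                              (empty Content-MD5)
\n                              (empty Content-Type)
Tue, 03 Dec 2019 12:00:00 GMT\n
/mybucket/test.csv

and sends `Authorization: AWS AKIAEXAMPLE:<base64 signature>`. The two limitations in the comment matter for correctness: a virtual hosted-style URL would need the bucket folded back into the signed resource path, and any x-amz-* headers would have to be canonicalized into the string to sign.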

dbms/src/IO/S3Common.h (new file, 19 lines)
View File

@@ -0,0 +1,19 @@
#pragma once
#include <Core/Types.h>
#include <Poco/Net/HTTPRequest.h>
namespace DB
{
namespace S3Helper
{
void authenticateRequest(
Poco::Net::HTTPRequest & request,
const String & access_key_id,
const String & secret_access_key);
};
}
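A minimal call site for the helper (a sketch; the wrapper function is hypothetical, only S3Helper::authenticateRequest is from the patch):

#include <IO/S3Common.h>
#include <Poco/Net/HTTPRequest.h>

// Build and sign a PUT for the given path with Signature V2.
Poco::Net::HTTPRequest makeSignedPut(
    const DB::String & path_and_query,
    const DB::String & access_key_id,
    const DB::String & secret_access_key)
{
    Poco::Net::HTTPRequest request(
        Poco::Net::HTTPRequest::HTTP_PUT, path_and_query, Poco::Net::HTTPRequest::HTTP_1_1);
    // Adds "Date" (when absent) and "Authorization: AWS <id>:<signature>".
    // Deliberately a no-op when access_key_id is empty, so anonymous buckets keep working.
    DB::S3Helper::authenticateRequest(request, access_key_id, secret_access_key);
    return request;
}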

View File

@@ -1,5 +1,6 @@
#include <IO/WriteBufferFromS3.h>
#include <IO/S3Common.h>
#include <IO/WriteHelpers.h>
#include <Poco/DOM/AutoPtr.h>
@@ -30,22 +31,22 @@ namespace ErrorCodes
WriteBufferFromS3::WriteBufferFromS3(
const Poco::URI & uri_,
const String & access_key_id_,
const String & secret_access_key_,
size_t minimum_upload_part_size_,
const ConnectionTimeouts & timeouts_,
const Poco::Net::HTTPBasicCredentials & credentials, size_t buffer_size_
)
: BufferWithOwnMemory<WriteBuffer>(buffer_size_, nullptr, 0)
const ConnectionTimeouts & timeouts_)
: BufferWithOwnMemory<WriteBuffer>(DBMS_DEFAULT_BUFFER_SIZE, nullptr, 0)
, uri {uri_}
, access_key_id {access_key_id_}
, secret_access_key {secret_access_key_}
, minimum_upload_part_size {minimum_upload_part_size_}
, timeouts {timeouts_}
, auth_request {Poco::Net::HTTPRequest::HTTP_PUT, uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1}
, temporary_buffer {std::make_unique<WriteBufferFromString>(buffer_string)}
, last_part_size {0}
{
if (!credentials.getUsername().empty())
credentials.authenticate(auth_request);
initiate();
/// FIXME: Implement rest of S3 authorization.
}
@@ -113,11 +114,7 @@ void WriteBufferFromS3::initiate()
request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, initiate_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request_ptr->setHost(initiate_uri.getHost()); // use original, not resolved host name in header
if (auth_request.hasCredentials())
{
Poco::Net::HTTPBasicCredentials credentials(auth_request);
credentials.authenticate(*request_ptr);
}
S3Helper::authenticateRequest(*request_ptr, access_key_id, secret_access_key);
request_ptr->setContentLength(0);
@@ -179,11 +176,7 @@ void WriteBufferFromS3::writePart(const String & data)
request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_PUT, part_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request_ptr->setHost(part_uri.getHost()); // use original, not resolved host name in header
if (auth_request.hasCredentials())
{
Poco::Net::HTTPBasicCredentials credentials(auth_request);
credentials.authenticate(*request_ptr);
}
S3Helper::authenticateRequest(*request_ptr, access_key_id, secret_access_key);
request_ptr->setExpectContinue(true);
@@ -252,11 +245,7 @@ void WriteBufferFromS3::complete()
request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, complete_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
request_ptr->setHost(complete_uri.getHost()); // use original, not resolved host name in header
if (auth_request.hasCredentials())
{
Poco::Net::HTTPBasicCredentials credentials(auth_request);
credentials.authenticate(*request_ptr);
}
S3Helper::authenticateRequest(*request_ptr, access_key_id, secret_access_key);
request_ptr->setExpectContinue(true);
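Each of the three call sites above drops the copy-credentials-and-reauthenticate dance in favor of one S3Helper::authenticateRequest call. They map onto the standard S3 multipart-upload protocol:

1. initiate(): POST /key?uploads -- the XML response carries the UploadId used by everything that follows.
2. writePart(): PUT /key?partNumber=N&uploadId=... once the buffered data reaches minimum_upload_part_size; the ETag of each response must be collected.
3. complete(): POST /key?uploadId=... whose XML body lists the collected part ETags.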

View File

@@ -21,9 +21,10 @@ class WriteBufferFromS3 : public BufferWithOwnMemory<WriteBuffer>
{
private:
Poco::URI uri;
String access_key_id;
String secret_access_key;
size_t minimum_upload_part_size;
ConnectionTimeouts timeouts;
Poco::Net::HTTPRequest auth_request;
String buffer_string;
std::unique_ptr<WriteBufferFromString> temporary_buffer;
size_t last_part_size;
@@ -35,10 +36,10 @@ private:
public:
explicit WriteBufferFromS3(const Poco::URI & uri,
const String & access_key_id,
const String & secret_access_key,
size_t minimum_upload_part_size_,
const ConnectionTimeouts & timeouts = {},
const Poco::Net::HTTPBasicCredentials & credentials = {},
size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE);
const ConnectionTimeouts & timeouts = {});
void nextImpl() override;

View File

@@ -568,45 +568,46 @@ inline void writeUUIDText(const UUID & uuid, WriteBuffer & buf)
buf.write(s, sizeof(s));
}
static const char digits100[201] =
"00010203040506070809"
"10111213141516171819"
"20212223242526272829"
"30313233343536373839"
"40414243444546474849"
"50515253545556575859"
"60616263646566676869"
"70717273747576777879"
"80818283848586878889"
"90919293949596979899";
/// in YYYY-MM-DD format
template <char delimiter = '-'>
inline void writeDateText(const LocalDate & date, WriteBuffer & buf)
{
static const char digits[201] =
"00010203040506070809"
"10111213141516171819"
"20212223242526272829"
"30313233343536373839"
"40414243444546474849"
"50515253545556575859"
"60616263646566676869"
"70717273747576777879"
"80818283848586878889"
"90919293949596979899";
if (buf.position() + 10 <= buf.buffer().end())
{
memcpy(buf.position(), &digits[date.year() / 100 * 2], 2);
memcpy(buf.position(), &digits100[date.year() / 100 * 2], 2);
buf.position() += 2;
memcpy(buf.position(), &digits[date.year() % 100 * 2], 2);
memcpy(buf.position(), &digits100[date.year() % 100 * 2], 2);
buf.position() += 2;
*buf.position() = delimiter;
++buf.position();
memcpy(buf.position(), &digits[date.month() * 2], 2);
memcpy(buf.position(), &digits100[date.month() * 2], 2);
buf.position() += 2;
*buf.position() = delimiter;
++buf.position();
memcpy(buf.position(), &digits[date.day() * 2], 2);
memcpy(buf.position(), &digits100[date.day() * 2], 2);
buf.position() += 2;
}
else
{
buf.write(&digits[date.year() / 100 * 2], 2);
buf.write(&digits[date.year() % 100 * 2], 2);
buf.write(&digits100[date.year() / 100 * 2], 2);
buf.write(&digits100[date.year() % 100 * 2], 2);
buf.write(delimiter);
buf.write(&digits[date.month() * 2], 2);
buf.write(&digits100[date.month() * 2], 2);
buf.write(delimiter);
buf.write(&digits[date.day() * 2], 2);
buf.write(&digits100[date.day() * 2], 2);
}
}
@@ -628,59 +629,47 @@ inline void writeDateText(DayNum date, WriteBuffer & buf)
template <char date_delimeter = '-', char time_delimeter = ':', char between_date_time_delimiter = ' '>
inline void writeDateTimeText(const LocalDateTime & datetime, WriteBuffer & buf)
{
static const char digits[201] =
"00010203040506070809"
"10111213141516171819"
"20212223242526272829"
"30313233343536373839"
"40414243444546474849"
"50515253545556575859"
"60616263646566676869"
"70717273747576777879"
"80818283848586878889"
"90919293949596979899";
if (buf.position() + 19 <= buf.buffer().end())
{
memcpy(buf.position(), &digits[datetime.year() / 100 * 2], 2);
memcpy(buf.position(), &digits100[datetime.year() / 100 * 2], 2);
buf.position() += 2;
memcpy(buf.position(), &digits[datetime.year() % 100 * 2], 2);
memcpy(buf.position(), &digits100[datetime.year() % 100 * 2], 2);
buf.position() += 2;
*buf.position() = date_delimeter;
++buf.position();
memcpy(buf.position(), &digits[datetime.month() * 2], 2);
memcpy(buf.position(), &digits100[datetime.month() * 2], 2);
buf.position() += 2;
*buf.position() = date_delimeter;
++buf.position();
memcpy(buf.position(), &digits[datetime.day() * 2], 2);
memcpy(buf.position(), &digits100[datetime.day() * 2], 2);
buf.position() += 2;
*buf.position() = between_date_time_delimiter;
++buf.position();
memcpy(buf.position(), &digits[datetime.hour() * 2], 2);
memcpy(buf.position(), &digits100[datetime.hour() * 2], 2);
buf.position() += 2;
*buf.position() = time_delimeter;
++buf.position();
memcpy(buf.position(), &digits[datetime.minute() * 2], 2);
memcpy(buf.position(), &digits100[datetime.minute() * 2], 2);
buf.position() += 2;
*buf.position() = time_delimeter;
++buf.position();
memcpy(buf.position(), &digits[datetime.second() * 2], 2);
memcpy(buf.position(), &digits100[datetime.second() * 2], 2);
buf.position() += 2;
}
else
{
buf.write(&digits[datetime.year() / 100 * 2], 2);
buf.write(&digits[datetime.year() % 100 * 2], 2);
buf.write(&digits100[datetime.year() / 100 * 2], 2);
buf.write(&digits100[datetime.year() % 100 * 2], 2);
buf.write(date_delimeter);
buf.write(&digits[datetime.month() * 2], 2);
buf.write(&digits100[datetime.month() * 2], 2);
buf.write(date_delimeter);
buf.write(&digits[datetime.day() * 2], 2);
buf.write(&digits100[datetime.day() * 2], 2);
buf.write(between_date_time_delimiter);
buf.write(&digits[datetime.hour() * 2], 2);
buf.write(&digits100[datetime.hour() * 2], 2);
buf.write(time_delimeter);
buf.write(&digits[datetime.minute() * 2], 2);
buf.write(&digits100[datetime.minute() * 2], 2);
buf.write(time_delimeter);
buf.write(&digits[datetime.second() * 2], 2);
buf.write(&digits100[datetime.second() * 2], 2);
}
}
@@ -707,6 +696,33 @@ inline void writeDateTimeText(time_t datetime, WriteBuffer & buf, const DateLUTI
}
/// In the RFC 1123 format: "Tue, 03 Dec 2019 00:11:50 GMT". You must provide GMT DateLUT.
/// This is needed for HTTP requests.
inline void writeDateTimeTextRFC1123(time_t datetime, WriteBuffer & buf, const DateLUTImpl & date_lut)
{
const auto & values = date_lut.getValues(datetime);
static const char week_days[3 * 8 + 1] = "XXX" "Mon" "Tue" "Wed" "Thu" "Fri" "Sat" "Sun";
static const char months[3 * 13 + 1] = "XXX" "Jan" "Feb" "Mar" "Apr" "May" "Jun" "Jul" "Aug" "Sep" "Oct" "Nov" "Dec";
buf.write(&week_days[values.day_of_week * 3], 3);
buf.write(", ", 2);
buf.write(&digits100[values.day_of_month * 2], 2);
buf.write(' ');
buf.write(&months[values.month * 3], 3);
buf.write(' ');
buf.write(&digits100[values.year / 100 * 2], 2);
buf.write(&digits100[values.year % 100 * 2], 2);
buf.write(' ');
buf.write(&digits100[date_lut.toHour(datetime) * 2], 2);
buf.write(':');
buf.write(&digits100[date_lut.toMinute(datetime) * 2], 2);
buf.write(':');
buf.write(&digits100[date_lut.toSecond(datetime) * 2], 2);
buf.write(" GMT", 4);
}
/// Methods for output in binary format.
template <typename T>
inline std::enable_if_t<is_arithmetic_v<T>, void>
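The formatting in writeDateTimeTextRFC1123 is hand-rolled rather than delegated to std::strftime because strftime's %a and %b fields are locale-dependent, while HTTP Date headers must always carry English day and month names. A rough standalone equivalent, correct only under the C locale:

#include <cstdio>
#include <ctime>

int main()
{
    std::time_t t = 1111111111;  // the same instant the new unit test below checks
    char buf[30];
    std::strftime(buf, sizeof(buf), "%a, %d %b %Y %H:%M:%S GMT", std::gmtime(&t));
    std::puts(buf);  // "Fri, 18 Mar 2005 01:58:31 GMT" in the C locale
}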

View File

@@ -0,0 +1,14 @@
#include <gtest/gtest.h>
#include <common/DateLUT.h>
#include <IO/WriteHelpers.h>
#include <IO/WriteBufferFromString.h>
TEST(RFC1123, Test)
{
using namespace DB;
WriteBufferFromOwnString out;
writeDateTimeTextRFC1123(1111111111, out, DateLUT::instance("UTC"));
ASSERT_EQ(out.str(), "Fri, 18 Mar 2005 01:58:31 GMT");
}

View File

@@ -32,6 +32,8 @@ namespace
{
public:
StorageS3BlockInputStream(const Poco::URI & uri,
const String & access_key_id,
const String & secret_access_key,
const String & format,
const String & name_,
const Block & sample_block,
@@ -41,7 +43,7 @@ namespace
const CompressionMethod compression_method)
: name(name_)
{
read_buf = getReadBuffer<ReadBufferFromS3>(compression_method, uri, timeouts);
read_buf = getReadBuffer<ReadBufferFromS3>(compression_method, uri, access_key_id, secret_access_key, timeouts);
reader = FormatFactory::instance().getInput(format, *read_buf, sample_block, context, max_block_size);
}
@@ -80,6 +82,8 @@ namespace
{
public:
StorageS3BlockOutputStream(const Poco::URI & uri,
const String & access_key_id,
const String & secret_access_key,
const String & format,
UInt64 min_upload_part_size,
const Block & sample_block_,
@@ -88,7 +92,13 @@ namespace
const CompressionMethod compression_method)
: sample_block(sample_block_)
{
write_buf = getWriteBuffer<WriteBufferFromS3>(compression_method, uri, min_upload_part_size, timeouts);
write_buf = getWriteBuffer<WriteBufferFromS3>(
compression_method,
uri,
access_key_id,
secret_access_key,
min_upload_part_size,
timeouts);
writer = FormatFactory::instance().getOutput(format, *write_buf, sample_block, context);
}
@@ -124,6 +134,8 @@ namespace
StorageS3::StorageS3(
const Poco::URI & uri_,
const String & access_key_id_,
const String & secret_access_key_,
const std::string & database_name_,
const std::string & table_name_,
const String & format_name_,
@@ -134,6 +146,8 @@ StorageS3::StorageS3(
const String & compression_method_ = "")
: IStorage(columns_)
, uri(uri_)
, access_key_id(access_key_id_)
, secret_access_key(secret_access_key_)
, context_global(context_)
, format_name(format_name_)
, database_name(database_name_)
@@ -156,6 +170,8 @@ BlockInputStreams StorageS3::read(
{
BlockInputStreamPtr block_input = std::make_shared<StorageS3BlockInputStream>(
uri,
access_key_id,
secret_access_key,
format_name,
getName(),
getHeaderBlock(column_names),
@@ -179,7 +195,13 @@ void StorageS3::rename(const String & /*new_path_to_db*/, const String & new_dat
BlockOutputStreamPtr StorageS3::write(const ASTPtr & /*query*/, const Context & /*context*/)
{
return std::make_shared<StorageS3BlockOutputStream>(
uri, format_name, min_upload_part_size, getSampleBlock(), context_global,
uri,
access_key_id,
secret_access_key,
format_name,
min_upload_part_size,
getSampleBlock(),
context_global,
ConnectionTimeouts::getHTTPTimeouts(context_global),
IStorage::chooseCompressionMethod(uri.toString(), compression_method));
}
@@ -190,29 +212,35 @@ void registerStorageS3(StorageFactory & factory)
{
ASTs & engine_args = args.engine_args;
if (engine_args.size() != 2 && engine_args.size() != 3)
if (engine_args.size() < 2 || engine_args.size() > 5)
throw Exception(
"Storage S3 requires 2 or 3 arguments: url, name of used format and compression_method.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
"Storage S3 requires 2 to 5 arguments: url, [access_key_id, secret_access_key], name of used format and [compression_method].", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
engine_args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[0], args.local_context);
for (size_t i = 0; i < engine_args.size(); ++i)
engine_args[i] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[i], args.local_context);
String url = engine_args[0]->as<ASTLiteral &>().value.safeGet<String>();
Poco::URI uri(url);
engine_args[1] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[1], args.local_context);
String format_name = engine_args[engine_args.size() - 1]->as<ASTLiteral &>().value.safeGet<String>();
String format_name = engine_args[1]->as<ASTLiteral &>().value.safeGet<String>();
String access_key_id;
String secret_access_key;
if (engine_args.size() >= 4)
{
access_key_id = engine_args[1]->as<ASTLiteral &>().value.safeGet<String>();
secret_access_key = engine_args[2]->as<ASTLiteral &>().value.safeGet<String>();
}
UInt64 min_upload_part_size = args.local_context.getSettingsRef().s3_min_upload_part_size;
String compression_method;
if (engine_args.size() == 3)
{
engine_args[2] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[2], args.local_context);
compression_method = engine_args[2]->as<ASTLiteral &>().value.safeGet<String>();
} else compression_method = "auto";
if (engine_args.size() == 3 || engine_args.size() == 5)
compression_method = engine_args.back()->as<ASTLiteral &>().value.safeGet<String>();
else
compression_method = "auto";
return StorageS3::create(uri, args.database_name, args.table_name, format_name, min_upload_part_size, args.columns, args.constraints, args.context);
return StorageS3::create(uri, access_key_id, secret_access_key, args.database_name, args.table_name, format_name, min_upload_part_size, args.columns, args.constraints, args.context);
});
}
}
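Net effect on syntax: the engine accepts ENGINE = S3('url'[, 'access_key_id', 'secret_access_key'], 'format'[, 'compression_method']), and credentials are detected purely by argument count -- with 4 or 5 arguments, positions 2 and 3 are taken as the key pair, and a count of 3 or 5 means the last argument is the compression method.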

View File

@@ -18,8 +18,10 @@ class StorageS3 : public ext::shared_ptr_helper<StorageS3>, public IStorage
public:
StorageS3(
const Poco::URI & uri_,
const std::string & database_name_,
const std::string & table_name_,
const String & access_key_id,
const String & secret_access_key,
const String & database_name_,
const String & table_name_,
const String & format_name_,
UInt64 min_upload_part_size_,
const ColumnsDescription & columns_,
@@ -56,6 +58,8 @@ public:
private:
Poco::URI uri;
String access_key_id;
String secret_access_key;
const Context & context_global;
String format_name;

View File

@@ -1,17 +1,84 @@
#include <Storages/StorageS3.h>
#include <Interpreters/evaluateConstantExpression.h>
#include <TableFunctions/TableFunctionFactory.h>
#include <TableFunctions/TableFunctionS3.h>
#include <TableFunctions/parseColumnsListForTableFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Poco/URI.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
StoragePtr TableFunctionS3::executeImpl(const ASTPtr & ast_function, const Context & context, const std::string & table_name) const
{
/// Parse args
ASTs & args_func = ast_function->children;
if (args_func.size() != 1)
throw Exception("Table function '" + getName() + "' must have arguments.", ErrorCodes::LOGICAL_ERROR);
ASTs & args = args_func.at(0)->children;
if (args.size() < 3 || args.size() > 6)
throw Exception("Table function '" + getName() + "' requires 3 to 6 arguments: url, [access_key_id, secret_access_key,] format, structure and [compression_method].",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (size_t i = 0; i < args.size(); ++i)
args[i] = evaluateConstantExpressionOrIdentifierAsLiteral(args[i], context);
String filename = args[0]->as<ASTLiteral &>().value.safeGet<String>();
String format;
String structure;
String access_key_id;
String secret_access_key;
if (args.size() < 5)
{
format = args[1]->as<ASTLiteral &>().value.safeGet<String>();
structure = args[2]->as<ASTLiteral &>().value.safeGet<String>();
}
else
{
access_key_id = args[1]->as<ASTLiteral &>().value.safeGet<String>();
secret_access_key = args[2]->as<ASTLiteral &>().value.safeGet<String>();
format = args[3]->as<ASTLiteral &>().value.safeGet<String>();
structure = args[4]->as<ASTLiteral &>().value.safeGet<String>();
}
String compression_method;
if (args.size() == 4 || args.size() == 6)
compression_method = args.back()->as<ASTLiteral &>().value.safeGet<String>();
else
compression_method = "auto";
ColumnsDescription columns = parseColumnsListFromString(structure, context);
/// Create table
StoragePtr storage = getStorage(filename, access_key_id, secret_access_key, format, columns, const_cast<Context &>(context), table_name, compression_method);
storage->startup();
return storage;
}
StoragePtr TableFunctionS3::getStorage(
const String & source, const String & format, const ColumnsDescription & columns, Context & global_context, const std::string & table_name, const String & compression_method) const
const String & source,
const String & access_key_id,
const String & secret_access_key,
const String & format,
const ColumnsDescription & columns,
Context & global_context,
const std::string & table_name,
const String & compression_method) const
{
Poco::URI uri(source);
UInt64 min_upload_part_size = global_context.getSettingsRef().s3_min_upload_part_size;
return StorageS3::create(uri, getDatabaseName(), table_name, format, min_upload_part_size, columns, ConstraintsDescription{}, global_context, compression_method);
return StorageS3::create(uri, access_key_id, secret_access_key, getDatabaseName(), table_name, format, min_upload_part_size, columns, ConstraintsDescription{}, global_context, compression_method);
}
void registerTableFunctionS3(TableFunctionFactory & factory)
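The table function mirrors the engine's convention: s3(url[, access_key_id, secret_access_key], format, structure[, compression_method]). With 3 or 4 arguments the url is followed directly by format and structure; with 5 or 6, arguments 2 and 3 are the key pair; a count of 4 or 6 means the last argument is the compression method. For example (endpoint and keys illustrative): s3('http://127.0.0.1:9001/bucket/test.csv', 'minio', 'minio123', 'CSV', 'column1 UInt32, column2 UInt32, column3 UInt32').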

View File

@ -1,6 +1,6 @@
#pragma once
#include <TableFunctions/ITableFunctionFileLike.h>
#include <TableFunctions/ITableFunction.h>
namespace DB
@@ -8,9 +8,9 @@ namespace DB
class Context;
/* s3(source, format, structure) - creates a temporary storage for a file in S3
/* s3(source, [access_key_id, secret_access_key,] format, structure) - creates a temporary storage for a file in S3
*/
class TableFunctionS3 : public ITableFunctionFileLike
class TableFunctionS3 : public ITableFunction
{
public:
static constexpr auto name = "s3";
@@ -20,13 +20,20 @@ public:
}
private:
StoragePtr executeImpl(
const ASTPtr & ast_function,
const Context & context,
const std::string & table_name) const override;
StoragePtr getStorage(
const String & source,
const String & access_key_id,
const String & secret_access_key,
const String & format,
const ColumnsDescription & columns,
Context & global_context,
const std::string & table_name,
const String & compression_method) const override;
const String & compression_method) const;
};
}

View File

@@ -5,6 +5,9 @@ import pytest
from helpers.cluster import ClickHouseCluster, ClickHouseInstance
import helpers.client
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler())
@@ -53,12 +56,18 @@ def prepare_s3_bucket(cluster):
minio_client.set_bucket_policy(cluster.minio_bucket, json.dumps(bucket_read_write_policy))
cluster.minio_restricted_bucket = "{}-with-auth".format(cluster.minio_bucket)
if minio_client.bucket_exists(cluster.minio_restricted_bucket):
minio_client.remove_bucket(cluster.minio_restricted_bucket)
minio_client.make_bucket(cluster.minio_restricted_bucket)
# Returns content of given S3 file as string.
def get_s3_file_content(cluster, filename):
def get_s3_file_content(cluster, bucket, filename):
# type: (ClickHouseCluster, str) -> str
data = cluster.minio_client.get_object(cluster.minio_bucket, filename)
data = cluster.minio_client.get_object(bucket, filename)
data_str = ""
for chunk in data.stream():
data_str += chunk
@@ -101,53 +110,76 @@ def run_query(instance, query, stdin=None, settings=None):
# Test simple put.
def test_put(cluster):
@pytest.mark.parametrize("maybe_auth,positive", [
("",True),
("'minio','minio123',",True),
("'wrongid','wrongkey',",False)
])
def test_put(cluster, maybe_auth, positive):
# type: (ClickHouseCluster) -> None
bucket = cluster.minio_bucket if not maybe_auth else cluster.minio_restricted_bucket
instance = cluster.instances["dummy"] # type: ClickHouseInstance
table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
values = "(1, 2, 3), (3, 2, 1), (78, 43, 45)"
values_csv = "1,2,3\n3,2,1\n78,43,45\n"
filename = "test.csv"
put_query = "insert into table function s3('http://{}:{}/{}/{}', 'CSV', '{}') values {}".format(
cluster.minio_host, cluster.minio_port, cluster.minio_bucket, filename, table_format, values)
run_query(instance, put_query)
put_query = "insert into table function s3('http://{}:{}/{}/{}', {}'CSV', '{}') values {}".format(
cluster.minio_host, cluster.minio_port, bucket, filename, maybe_auth, table_format, values)
assert values_csv == get_s3_file_content(cluster, filename)
try:
run_query(instance, put_query)
except helpers.client.QueryRuntimeException:
assert not positive
else:
assert positive
assert values_csv == get_s3_file_content(cluster, bucket, filename)
# Test put values in CSV format.
def test_put_csv(cluster):
@pytest.mark.parametrize("maybe_auth,positive", [
("",True),
("'minio','minio123',",True),
("'wrongid','wrongkey',",False)
])
def test_put_csv(cluster, maybe_auth, positive):
# type: (ClickHouseCluster) -> None
bucket = cluster.minio_bucket if not maybe_auth else cluster.minio_restricted_bucket
instance = cluster.instances["dummy"] # type: ClickHouseInstance
table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
filename = "test.csv"
put_query = "insert into table function s3('http://{}:{}/{}/{}', 'CSV', '{}') format CSV".format(
cluster.minio_host, cluster.minio_port, cluster.minio_bucket, filename, table_format)
put_query = "insert into table function s3('http://{}:{}/{}/{}', {}'CSV', '{}') format CSV".format(
cluster.minio_host, cluster.minio_port, bucket, filename, maybe_auth, table_format)
csv_data = "8,9,16\n11,18,13\n22,14,2\n"
run_query(instance, put_query, stdin=csv_data)
assert csv_data == get_s3_file_content(cluster, filename)
try:
run_query(instance, put_query, stdin=csv_data)
except helpers.client.QueryRuntimeException:
assert not positive
else:
assert positive
assert csv_data == get_s3_file_content(cluster, bucket, filename)
# Test put and get with S3 server redirect.
def test_put_get_with_redirect(cluster):
# type: (ClickHouseCluster) -> None
bucket = cluster.minio_bucket
instance = cluster.instances["dummy"] # type: ClickHouseInstance
table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
values = "(1, 1, 1), (1, 1, 1), (11, 11, 11)"
values_csv = "1,1,1\n1,1,1\n11,11,11\n"
filename = "test.csv"
query = "insert into table function s3('http://{}:{}/{}/{}', 'CSV', '{}') values {}".format(
cluster.minio_redirect_host, cluster.minio_redirect_port, cluster.minio_bucket, filename, table_format, values)
cluster.minio_redirect_host, cluster.minio_redirect_port, bucket, filename, table_format, values)
run_query(instance, query)
assert values_csv == get_s3_file_content(cluster, filename)
assert values_csv == get_s3_file_content(cluster, bucket, filename)
query = "select *, column1*column2*column3 from s3('http://{}:{}/{}/{}', 'CSV', '{}')".format(
cluster.minio_redirect_host, cluster.minio_redirect_port, cluster.minio_bucket, filename, table_format)
cluster.minio_redirect_host, cluster.minio_redirect_port, bucket, filename, table_format)
stdout = run_query(instance, query)
assert list(map(str.split, stdout.splitlines())) == [
@@ -158,9 +190,15 @@ def test_put_get_with_redirect(cluster):
# Test multipart put.
def test_multipart_put(cluster):
@pytest.mark.parametrize("maybe_auth,positive", [
("",True),
("'minio','minio123',",True),
("'wrongid','wrongkey',",False)
])
def test_multipart_put(cluster, maybe_auth, positive):
# type: (ClickHouseCluster) -> None
bucket = cluster.minio_bucket if not maybe_auth else cluster.minio_restricted_bucket
instance = cluster.instances["dummy"] # type: ClickHouseInstance
table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
@@ -178,14 +216,19 @@ def test_multipart_put(cluster):
assert len(csv_data) > min_part_size_bytes
filename = "test_multipart.csv"
put_query = "insert into table function s3('http://{}:{}/{}/{}', 'CSV', '{}') format CSV".format(
cluster.minio_redirect_host, cluster.minio_redirect_port, cluster.minio_bucket, filename, table_format)
put_query = "insert into table function s3('http://{}:{}/{}/{}', {}'CSV', '{}') format CSV".format(
cluster.minio_redirect_host, cluster.minio_redirect_port, bucket, filename, maybe_auth, table_format)
run_query(instance, put_query, stdin=csv_data, settings={'s3_min_upload_part_size': min_part_size_bytes})
try:
run_query(instance, put_query, stdin=csv_data, settings={'s3_min_upload_part_size': min_part_size_bytes})
except helpers.client.QueryRuntimeException:
assert not positive
else:
assert positive
# Use Nginx access logs to count number of parts uploaded to Minio.
nginx_logs = get_nginx_access_logs()
uploaded_parts = filter(lambda log_line: log_line.find(filename) >= 0 and log_line.find("PUT") >= 0, nginx_logs)
assert uploaded_parts > 1
# Use Nginx access logs to count number of parts uploaded to Minio.
nginx_logs = get_nginx_access_logs()
uploaded_parts = filter(lambda log_line: log_line.find(filename) >= 0 and log_line.find("PUT") >= 0, nginx_logs)
assert uploaded_parts > 1
assert csv_data == get_s3_file_content(cluster, filename)
assert csv_data == get_s3_file_content(cluster, bucket, filename)
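The put tests now run three times each via the (maybe_auth, positive) parametrization: anonymously against the public bucket, with the valid 'minio'/'minio123' pair against the new restricted bucket, and with a wrong pair, where positive is False and the QueryRuntimeException is the expected outcome rather than a failure.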

View File

@@ -2,7 +2,7 @@ drop table if exists test_table_s3_syntax
;
create table test_table_s3_syntax (id UInt32) ENGINE = S3('')
; -- { serverError 42 }
create table test_table_s3_syntax (id UInt32) ENGINE = S3('','','','')
create table test_table_s3_syntax (id UInt32) ENGINE = S3('','','','','','')
; -- { serverError 42 }
drop table if exists test_table_s3_syntax
;
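Both statements still fail with serverError 42 (NUMBER_OF_ARGUMENTS_DOESNT_MATCH): one argument is below the new minimum of two, and the six-argument form exercises the new maximum of five -- the four-argument form this test previously used became valid with this patch, since it now parses as url plus a credential pair plus format.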