diff --git a/dbms/src/Common/ErrorCodes.cpp b/dbms/src/Common/ErrorCodes.cpp
index 4f49ca92df4..0ea287a01e9 100644
--- a/dbms/src/Common/ErrorCodes.cpp
+++ b/dbms/src/Common/ErrorCodes.cpp
@@ -464,6 +464,7 @@ namespace ErrorCodes
     extern const int CANNOT_GET_CREATE_DICTIONARY_QUERY = 487;
     extern const int UNKNOWN_DICTIONARY = 488;
     extern const int INCORRECT_DICTIONARY_DEFINITION = 489;
+    extern const int CANNOT_FORMAT_DATETIME = 490;

     extern const int KEEPER_EXCEPTION = 999;
     extern const int POCO_EXCEPTION = 1000;
diff --git a/dbms/src/Functions/formatDateTime.cpp b/dbms/src/Functions/formatDateTime.cpp
index 8cecdb69717..c7150515935 100644
--- a/dbms/src/Functions/formatDateTime.cpp
+++ b/dbms/src/Functions/formatDateTime.cpp
@@ -91,19 +91,7 @@ private:
     template <typename T>
     static inline void writeNumber2(char * p, T v)
     {
-        static const char digits[201] =
-            "00010203040506070809"
-            "10111213141516171819"
-            "20212223242526272829"
-            "30313233343536373839"
-            "40414243444546474849"
-            "50515253545556575859"
-            "60616263646566676869"
-            "70717273747576777879"
-            "80818283848586878889"
-            "90919293949596979899";
-
-        memcpy(p, &digits[v * 2], 2);
+        memcpy(p, &digits100[v * 2], 2);
     }

     template <typename T>
diff --git a/dbms/src/IO/ReadBufferFromS3.cpp b/dbms/src/IO/ReadBufferFromS3.cpp
index ae09f0fb189..b26a8b8c316 100644
--- a/dbms/src/IO/ReadBufferFromS3.cpp
+++ b/dbms/src/IO/ReadBufferFromS3.cpp
@@ -1,6 +1,7 @@
 #include <IO/ReadBufferFromS3.h>

 #include <common/logger_useful.h>
+#include <IO/S3Common.h>

 #include <Poco/Net/HTTPRequest.h>

@@ -10,13 +11,12 @@ namespace DB

 const int DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT = 2;

-ReadBufferFromS3::ReadBufferFromS3(Poco::URI uri_,
-    const ConnectionTimeouts & timeouts,
-    const Poco::Net::HTTPBasicCredentials & credentials,
-    size_t buffer_size_)
+ReadBufferFromS3::ReadBufferFromS3(const Poco::URI & uri_,
+    const String & access_key_id_,
+    const String & secret_access_key_,
+    const ConnectionTimeouts & timeouts)
     : ReadBuffer(nullptr, 0)
     , uri {uri_}
-    , method {Poco::Net::HTTPRequest::HTTP_GET}
     , session {makeHTTPSession(uri_, timeouts)}
 {
     Poco::Net::HTTPResponse response;
@@ -28,11 +28,13 @@ ReadBufferFromS3::ReadBufferFromS3(Poco::URI uri_,
         if (uri.getPath().empty())
             uri.setPath("/");

-        request = std::make_unique<Poco::Net::HTTPRequest>(method, uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
+        request = std::make_unique<Poco::Net::HTTPRequest>(
+            Poco::Net::HTTPRequest::HTTP_GET,
+            uri.getPathAndQuery(),
+            Poco::Net::HTTPRequest::HTTP_1_1);
         request->setHost(uri.getHost()); // use original, not resolved host name in header

-        if (!credentials.getUsername().empty())
-            credentials.authenticate(*request);
+        S3Helper::authenticateRequest(*request, access_key_id_, secret_access_key_);

         LOG_TRACE((&Logger::get("ReadBufferFromS3")), "Sending request to " << uri.toString());

@@ -54,7 +56,7 @@ ReadBufferFromS3::ReadBufferFromS3(Poco::URI uri_,
     }

     assertResponseIsOk(*request, response, *istr);
-    impl = std::make_unique<ReadBufferFromIStream>(*istr, buffer_size_);
+    impl = std::make_unique<ReadBufferFromIStream>(*istr, DBMS_DEFAULT_BUFFER_SIZE);
 }

diff --git a/dbms/src/IO/ReadBufferFromS3.h b/dbms/src/IO/ReadBufferFromS3.h
index ffc0c5c0ab1..071ee7802a2 100644
--- a/dbms/src/IO/ReadBufferFromS3.h
+++ b/dbms/src/IO/ReadBufferFromS3.h
@@ -17,17 +17,15 @@ class ReadBufferFromS3 : public ReadBuffer
 {
 protected:
     Poco::URI uri;
-    std::string method;
-
     HTTPSessionPtr session;
     std::istream * istr; /// owned by session
     std::unique_ptr<ReadBufferFromIStream> impl;

 public:
-    explicit ReadBufferFromS3(Poco::URI uri_,
-        const ConnectionTimeouts & timeouts = {},
-        const Poco::Net::HTTPBasicCredentials & credentials = {},
-        size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE);
+    explicit ReadBufferFromS3(const Poco::URI & uri_,
+        const String & access_key_id_,
+        const String & secret_access_key_,
+        const ConnectionTimeouts & timeouts = {});

     bool nextImpl() override;
 };
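
Callers of ReadBufferFromS3 now hand over the key pair directly instead of a Poco::Net::HTTPBasicCredentials object, and the internal buffer size is pinned to DBMS_DEFAULT_BUFFER_SIZE. A minimal usage sketch against the new constructor; the endpoint, bucket and keys are made-up placeholders:

    #include <IO/ReadBufferFromS3.h>
    #include <IO/WriteBufferFromFile.h>
    #include <IO/copyData.h>

    void downloadObject()
    {
        /// Hypothetical Minio endpoint and credentials.
        Poco::URI uri("http://127.0.0.1:9001/bucket/test.csv");
        DB::ReadBufferFromS3 in(uri, "minio", "minio123");   /// timeouts default to {}
        DB::WriteBufferFromFile out("/tmp/test.csv");
        DB::copyData(in, out);                               /// stream the S3 object into a local file
        out.next();                                          /// flush
    }

Passing empty keys keeps the request unsigned, since authenticateRequest returns early in that case; that is how anonymous access to public buckets keeps working.
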
diff --git a/dbms/src/IO/S3Common.cpp b/dbms/src/IO/S3Common.cpp
new file mode 100644
index 00000000000..1233bae38e1
--- /dev/null
+++ b/dbms/src/IO/S3Common.cpp
@@ -0,0 +1,60 @@
+#include <IO/S3Common.h>
+#include <IO/WriteBufferFromString.h>
+#include <IO/WriteHelpers.h>
+
+#include <Poco/Base64Encoder.h>
+#include <Poco/HMACEngine.h>
+
+#include <Poco/Net/HTTPRequest.h>
+#include <Poco/SHA1Engine.h>
+#include <Poco/URI.h>
+#include <common/DateLUT.h>
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int CANNOT_FORMAT_DATETIME;
+}
+
+void S3Helper::authenticateRequest(
+    Poco::Net::HTTPRequest & request,
+    const String & access_key_id,
+    const String & secret_access_key)
+{
+    /// See https://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html
+
+    if (access_key_id.empty())
+        return;
+
+    /// Limitations:
+    /// 1. Virtual hosted-style requests are not supported (e.g. `http://johnsmith.net.s3.amazonaws.com/homepage.html`).
+    /// 2. AMZ headers are not supported (TODO).
+
+    if (!request.has("Date"))
+    {
+        WriteBufferFromOwnString out;
+        writeDateTimeTextRFC1123(time(nullptr), out, DateLUT::instance("UTC"));
+        request.set("Date", out.str());
+    }
+
+    String string_to_sign = request.getMethod() + "\n"
+        + request.get("Content-MD5", "") + "\n"
+        + request.get("Content-Type", "") + "\n"
+        + request.get("Date") + "\n"
+        + Poco::URI(request.getURI()).getPathAndQuery();
+
+    Poco::HMACEngine<Poco::SHA1Engine> engine(secret_access_key);
+    engine.update(string_to_sign);
+    auto digest = engine.digest();
+    std::ostringstream signature;
+    Poco::Base64Encoder encoder(signature);
+    std::copy(digest.begin(), digest.end(), std::ostream_iterator<char>(encoder));
+    encoder.close();
+
+    request.set("Authorization", "AWS " + access_key_id + ":" + signature.str());
+}
+
+}
diff --git a/dbms/src/IO/S3Common.h b/dbms/src/IO/S3Common.h
new file mode 100644
index 00000000000..b68f5c9b536
--- /dev/null
+++ b/dbms/src/IO/S3Common.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <Core/Types.h>
+
+#include <Poco/Net/HTTPRequest.h>
+
+
+namespace DB
+{
+
+namespace S3Helper
+{
+    void authenticateRequest(
+        Poco::Net::HTTPRequest & request,
+        const String & access_key_id,
+        const String & secret_access_key);
+};
+
+}
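
S3Helper::authenticateRequest implements AWS Signature Version 2: the string-to-sign is the HTTP verb, Content-MD5, Content-Type, Date and the request's path-and-query joined by newlines; it is HMAC-SHA1-signed with the secret key and Base64-encoded into an "Authorization: AWS <access_key_id>:<signature>" header. As the limitations comment says, the resource is taken from getPathAndQuery() rather than the fully canonicalized AMZ resource, which is why virtual-hosted-style buckets and x-amz-* headers are out of scope. Below is a standalone sketch of the same computation; the key, date and path mimic the well-known GET example from the AWS documentation, and the printed value is illustrative only:

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <sstream>

    #include <Poco/Base64Encoder.h>
    #include <Poco/HMACEngine.h>
    #include <Poco/SHA1Engine.h>

    int main()
    {
        /// Sample values in the style of the AWS REST authentication docs (not real credentials).
        std::string secret_access_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
        std::string string_to_sign =
            "GET\n"                                  /// HTTP verb
            "\n"                                     /// Content-MD5, empty for GET
            "\n"                                     /// Content-Type, empty for GET
            "Tue, 27 Mar 2007 19:36:42 +0000\n"      /// Date header
            "/awsexamplebucket1/photos/puppy.jpg";   /// path and query

        Poco::HMACEngine<Poco::SHA1Engine> engine(secret_access_key);
        engine.update(string_to_sign);
        auto digest = engine.digest();

        std::ostringstream signature;
        Poco::Base64Encoder encoder(signature);
        std::copy(digest.begin(), digest.end(), std::ostream_iterator<char>(encoder));
        encoder.close();

        /// The patch would emit this string as the Authorization header value.
        std::cout << "AWS AKIAIOSFODNN7EXAMPLE:" << signature.str() << std::endl;
    }
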
diff --git a/dbms/src/IO/WriteBufferFromS3.cpp b/dbms/src/IO/WriteBufferFromS3.cpp
index 9604b6ce199..4154db48282 100644
--- a/dbms/src/IO/WriteBufferFromS3.cpp
+++ b/dbms/src/IO/WriteBufferFromS3.cpp
@@ -1,5 +1,6 @@
 #include <IO/WriteBufferFromS3.h>

+#include <IO/S3Common.h>
 #include <IO/WriteHelpers.h>
 #include <common/logger_useful.h>

@@ -30,22 +31,22 @@ namespace ErrorCodes
 WriteBufferFromS3::WriteBufferFromS3(
     const Poco::URI & uri_,
+    const String & access_key_id_,
+    const String & secret_access_key_,
     size_t minimum_upload_part_size_,
-    const ConnectionTimeouts & timeouts_,
-    const Poco::Net::HTTPBasicCredentials & credentials,
-    size_t buffer_size_
-)
-    : BufferWithOwnMemory<WriteBuffer>(buffer_size_, nullptr, 0)
+    const ConnectionTimeouts & timeouts_)
+    : BufferWithOwnMemory<WriteBuffer>(DBMS_DEFAULT_BUFFER_SIZE, nullptr, 0)
     , uri {uri_}
+    , access_key_id {access_key_id_}
+    , secret_access_key {secret_access_key_}
     , minimum_upload_part_size {minimum_upload_part_size_}
     , timeouts {timeouts_}
-    , auth_request {Poco::Net::HTTPRequest::HTTP_PUT, uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1}
     , temporary_buffer {std::make_unique<WriteBufferFromString>(buffer_string)}
     , last_part_size {0}
 {
-    if (!credentials.getUsername().empty())
-        credentials.authenticate(auth_request);
-
     initiate();
+
+    /// FIXME: Implement rest of S3 authorization.
 }

@@ -113,11 +114,7 @@ void WriteBufferFromS3::initiate()
     request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, initiate_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
     request_ptr->setHost(initiate_uri.getHost()); // use original, not resolved host name in header

-    if (auth_request.hasCredentials())
-    {
-        Poco::Net::HTTPBasicCredentials credentials(auth_request);
-        credentials.authenticate(*request_ptr);
-    }
+    S3Helper::authenticateRequest(*request_ptr, access_key_id, secret_access_key);

     request_ptr->setContentLength(0);

@@ -179,11 +176,7 @@ void WriteBufferFromS3::writePart(const String & data)
     request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_PUT, part_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
     request_ptr->setHost(part_uri.getHost()); // use original, not resolved host name in header

-    if (auth_request.hasCredentials())
-    {
-        Poco::Net::HTTPBasicCredentials credentials(auth_request);
-        credentials.authenticate(*request_ptr);
-    }
+    S3Helper::authenticateRequest(*request_ptr, access_key_id, secret_access_key);

     request_ptr->setExpectContinue(true);

@@ -252,11 +245,7 @@ void WriteBufferFromS3::complete()
     request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, complete_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
     request_ptr->setHost(complete_uri.getHost()); // use original, not resolved host name in header

-    if (auth_request.hasCredentials())
-    {
-        Poco::Net::HTTPBasicCredentials credentials(auth_request);
-        credentials.authenticate(*request_ptr);
-    }
+    S3Helper::authenticateRequest(*request_ptr, access_key_id, secret_access_key);

     request_ptr->setExpectContinue(true);

diff --git a/dbms/src/IO/WriteBufferFromS3.h b/dbms/src/IO/WriteBufferFromS3.h
index 9a619f8c8bc..6f89f7c36ec 100644
--- a/dbms/src/IO/WriteBufferFromS3.h
+++ b/dbms/src/IO/WriteBufferFromS3.h
@@ -21,9 +21,10 @@ class WriteBufferFromS3 : public BufferWithOwnMemory<WriteBuffer>
 {
 private:
     Poco::URI uri;
+    String access_key_id;
+    String secret_access_key;
     size_t minimum_upload_part_size;
     ConnectionTimeouts timeouts;
-    Poco::Net::HTTPRequest auth_request;
     String buffer_string;
     std::unique_ptr<WriteBufferFromString> temporary_buffer;
     size_t last_part_size;
@@ -35,10 +36,10 @@ private:

 public:
     explicit WriteBufferFromS3(const Poco::URI & uri,
+        const String & access_key_id,
+        const String & secret_access_key,
         size_t minimum_upload_part_size_,
-        const ConnectionTimeouts & timeouts = {},
-        const Poco::Net::HTTPBasicCredentials & credentials = {},
-        size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE);
+        const ConnectionTimeouts & timeouts = {});

     void nextImpl() override;
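
WriteBufferFromS3 gets the same treatment: the key pair replaces the pre-authenticated auth_request member, and initiate(), writePart() and complete() each sign their own request through S3Helper, so every round-trip of the multipart upload carries a fresh signature. A minimal upload sketch against the new signature; the endpoint and keys are placeholders, and the finalize() call (what StorageS3's output stream uses to complete the multipart upload) is an assumption based on the surrounding code:

    #include <IO/WriteBufferFromS3.h>

    void uploadObject()
    {
        /// Hypothetical Minio endpoint and credentials.
        Poco::URI uri("http://127.0.0.1:9001/bucket/out.csv");
        DB::WriteBufferFromS3 out(uri, "minio", "minio123",
            /* minimum_upload_part_size = */ 5 * 1024 * 1024);
        out.write("1,2,3\n", 6);   /// buffered; flushed to S3 as multipart parts
        out.finalize();            /// assumed: sends the CompleteMultipartUpload request
    }
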
"90919293949596979899"; - if (buf.position() + 10 <= buf.buffer().end()) { - memcpy(buf.position(), &digits[date.year() / 100 * 2], 2); + memcpy(buf.position(), &digits100[date.year() / 100 * 2], 2); buf.position() += 2; - memcpy(buf.position(), &digits[date.year() % 100 * 2], 2); + memcpy(buf.position(), &digits100[date.year() % 100 * 2], 2); buf.position() += 2; *buf.position() = delimiter; ++buf.position(); - memcpy(buf.position(), &digits[date.month() * 2], 2); + memcpy(buf.position(), &digits100[date.month() * 2], 2); buf.position() += 2; *buf.position() = delimiter; ++buf.position(); - memcpy(buf.position(), &digits[date.day() * 2], 2); + memcpy(buf.position(), &digits100[date.day() * 2], 2); buf.position() += 2; } else { - buf.write(&digits[date.year() / 100 * 2], 2); - buf.write(&digits[date.year() % 100 * 2], 2); + buf.write(&digits100[date.year() / 100 * 2], 2); + buf.write(&digits100[date.year() % 100 * 2], 2); buf.write(delimiter); - buf.write(&digits[date.month() * 2], 2); + buf.write(&digits100[date.month() * 2], 2); buf.write(delimiter); - buf.write(&digits[date.day() * 2], 2); + buf.write(&digits100[date.day() * 2], 2); } } @@ -628,59 +629,47 @@ inline void writeDateText(DayNum date, WriteBuffer & buf) template inline void writeDateTimeText(const LocalDateTime & datetime, WriteBuffer & buf) { - static const char digits[201] = - "00010203040506070809" - "10111213141516171819" - "20212223242526272829" - "30313233343536373839" - "40414243444546474849" - "50515253545556575859" - "60616263646566676869" - "70717273747576777879" - "80818283848586878889" - "90919293949596979899"; - if (buf.position() + 19 <= buf.buffer().end()) { - memcpy(buf.position(), &digits[datetime.year() / 100 * 2], 2); + memcpy(buf.position(), &digits100[datetime.year() / 100 * 2], 2); buf.position() += 2; - memcpy(buf.position(), &digits[datetime.year() % 100 * 2], 2); + memcpy(buf.position(), &digits100[datetime.year() % 100 * 2], 2); buf.position() += 2; *buf.position() = date_delimeter; ++buf.position(); - memcpy(buf.position(), &digits[datetime.month() * 2], 2); + memcpy(buf.position(), &digits100[datetime.month() * 2], 2); buf.position() += 2; *buf.position() = date_delimeter; ++buf.position(); - memcpy(buf.position(), &digits[datetime.day() * 2], 2); + memcpy(buf.position(), &digits100[datetime.day() * 2], 2); buf.position() += 2; *buf.position() = between_date_time_delimiter; ++buf.position(); - memcpy(buf.position(), &digits[datetime.hour() * 2], 2); + memcpy(buf.position(), &digits100[datetime.hour() * 2], 2); buf.position() += 2; *buf.position() = time_delimeter; ++buf.position(); - memcpy(buf.position(), &digits[datetime.minute() * 2], 2); + memcpy(buf.position(), &digits100[datetime.minute() * 2], 2); buf.position() += 2; *buf.position() = time_delimeter; ++buf.position(); - memcpy(buf.position(), &digits[datetime.second() * 2], 2); + memcpy(buf.position(), &digits100[datetime.second() * 2], 2); buf.position() += 2; } else { - buf.write(&digits[datetime.year() / 100 * 2], 2); - buf.write(&digits[datetime.year() % 100 * 2], 2); + buf.write(&digits100[datetime.year() / 100 * 2], 2); + buf.write(&digits100[datetime.year() % 100 * 2], 2); buf.write(date_delimeter); - buf.write(&digits[datetime.month() * 2], 2); + buf.write(&digits100[datetime.month() * 2], 2); buf.write(date_delimeter); - buf.write(&digits[datetime.day() * 2], 2); + buf.write(&digits100[datetime.day() * 2], 2); buf.write(between_date_time_delimiter); - buf.write(&digits[datetime.hour() * 2], 2); + 
diff --git a/dbms/src/IO/tests/gtest_rfc1123.cpp b/dbms/src/IO/tests/gtest_rfc1123.cpp
new file mode 100644
index 00000000000..66d7484de1f
--- /dev/null
+++ b/dbms/src/IO/tests/gtest_rfc1123.cpp
@@ -0,0 +1,14 @@
+#include <gtest/gtest.h>
+
+#include <IO/WriteBufferFromString.h>
+#include <IO/WriteHelpers.h>
+#include <common/DateLUT.h>
+
+
+TEST(RFC1123, Test)
+{
+    using namespace DB;
+    WriteBufferFromOwnString out;
+    writeDateTimeTextRFC1123(1111111111, out, DateLUT::instance("UTC"));
+    ASSERT_EQ(out.str(), "Fri, 18 Mar 2005 01:58:31 GMT");
+}
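
The test above pins the output format; in the S3 code the same helper produces the live Date header. A sketch of that call pattern, mirroring what S3Common.cpp does:

    #include <ctime>

    #include <IO/WriteBufferFromString.h>
    #include <IO/WriteHelpers.h>
    #include <common/DateLUT.h>

    /// Value for an HTTP "Date" header, e.g. "Fri, 18 Mar 2005 01:58:31 GMT".
    std::string currentHTTPDate()
    {
        DB::WriteBufferFromOwnString out;
        DB::writeDateTimeTextRFC1123(time(nullptr), out, DateLUT::instance("UTC"));
        return out.str();
    }
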
diff --git a/dbms/src/Storages/StorageS3.cpp b/dbms/src/Storages/StorageS3.cpp
index ed9173c52ec..df7313805d9 100644
--- a/dbms/src/Storages/StorageS3.cpp
+++ b/dbms/src/Storages/StorageS3.cpp
@@ -32,6 +32,8 @@ namespace
     {
     public:
         StorageS3BlockInputStream(const Poco::URI & uri,
+            const String & access_key_id,
+            const String & secret_access_key,
             const String & format,
             const String & name_,
             const Block & sample_block,
@@ -41,7 +43,7 @@ namespace
             const CompressionMethod compression_method)
             : name(name_)
         {
-            read_buf = getReadBuffer(compression_method, uri, timeouts);
+            read_buf = getReadBuffer(compression_method, uri, access_key_id, secret_access_key, timeouts);
             reader = FormatFactory::instance().getInput(format, *read_buf, sample_block, context, max_block_size);
         }

@@ -80,6 +82,8 @@ namespace
     {
     public:
         StorageS3BlockOutputStream(const Poco::URI & uri,
+            const String & access_key_id,
+            const String & secret_access_key,
             const String & format,
             UInt64 min_upload_part_size,
             const Block & sample_block_,
@@ -88,7 +92,13 @@ namespace
             const CompressionMethod compression_method)
             : sample_block(sample_block_)
         {
-            write_buf = getWriteBuffer(compression_method, uri, min_upload_part_size, timeouts);
+            write_buf = getWriteBuffer(
+                compression_method,
+                uri,
+                access_key_id,
+                secret_access_key,
+                min_upload_part_size,
+                timeouts);
             writer = FormatFactory::instance().getOutput(format, *write_buf, sample_block, context);
         }

@@ -124,6 +134,8 @@
 StorageS3::StorageS3(
     const Poco::URI & uri_,
+    const String & access_key_id_,
+    const String & secret_access_key_,
     const std::string & database_name_,
     const std::string & table_name_,
     const String & format_name_,
@@ -134,6 +146,8 @@ StorageS3::StorageS3(
     const String & compression_method_ = "")
     : IStorage(columns_)
     , uri(uri_)
+    , access_key_id(access_key_id_)
+    , secret_access_key(secret_access_key_)
     , context_global(context_)
     , format_name(format_name_)
     , database_name(database_name_)
@@ -156,6 +170,8 @@ BlockInputStreams StorageS3::read(
 {
     BlockInputStreamPtr block_input = std::make_shared<StorageS3BlockInputStream>(
         uri,
+        access_key_id,
+        secret_access_key,
         format_name,
         getName(),
         getHeaderBlock(column_names),
@@ -179,7 +195,13 @@ void StorageS3::rename(const String & /*new_path_to_db*/, const String & new_database_name,
 BlockOutputStreamPtr StorageS3::write(const ASTPtr & /*query*/, const Context & /*context*/)
 {
     return std::make_shared<StorageS3BlockOutputStream>(
-        uri, format_name, min_upload_part_size, getSampleBlock(), context_global,
+        uri,
+        access_key_id,
+        secret_access_key,
+        format_name,
+        min_upload_part_size,
+        getSampleBlock(),
+        context_global,
         ConnectionTimeouts::getHTTPTimeouts(context_global),
         IStorage::chooseCompressionMethod(uri.toString(), compression_method));
 }

@@ -190,29 +212,35 @@ void registerStorageS3(StorageFactory & factory)
     {
         ASTs & engine_args = args.engine_args;

-        if (engine_args.size() != 2 && engine_args.size() != 3)
+        if (engine_args.size() < 2 || engine_args.size() > 5)
             throw Exception(
-                "Storage S3 requires 2 or 3 arguments: url, name of used format and compression_method.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
+                "Storage S3 requires 2 to 5 arguments: url, [access_key_id, secret_access_key,] name of used format and [compression_method].", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

-        engine_args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[0], args.local_context);
+        for (size_t i = 0; i < engine_args.size(); ++i)
+            engine_args[i] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[i], args.local_context);

         String url = engine_args[0]->as<ASTLiteral &>().value.safeGet<String>();
         Poco::URI uri(url);

-        engine_args[1] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[1], args.local_context);
+        /// The compression method, when given, is always the last argument; the format precedes it.
+        bool has_compression_method = engine_args.size() == 3 || engine_args.size() == 5;
+        String format_name = engine_args[engine_args.size() - (has_compression_method ? 2 : 1)]->as<ASTLiteral &>().value.safeGet<String>();

-        String format_name = engine_args[1]->as<ASTLiteral &>().value.safeGet<String>();
+        String access_key_id;
+        String secret_access_key;
+        if (engine_args.size() >= 4)
+        {
+            access_key_id = engine_args[1]->as<ASTLiteral &>().value.safeGet<String>();
+            secret_access_key = engine_args[2]->as<ASTLiteral &>().value.safeGet<String>();
+        }

         UInt64 min_upload_part_size = args.local_context.getSettingsRef().s3_min_upload_part_size;

         String compression_method;
-        if (engine_args.size() == 3)
-        {
-            engine_args[2] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[2], args.local_context);
-            compression_method = engine_args[2]->as<ASTLiteral &>().value.safeGet<String>();
-        } else compression_method = "auto";
+        if (has_compression_method)
+            compression_method = engine_args.back()->as<ASTLiteral &>().value.safeGet<String>();
+        else
+            compression_method = "auto";

-        return StorageS3::create(uri, args.database_name, args.table_name, format_name, min_upload_part_size, args.columns, args.constraints, args.context);
+        return StorageS3::create(uri, access_key_id, secret_access_key, args.database_name, args.table_name, format_name, min_upload_part_size, args.columns, args.constraints, args.context);
     });
 }
 }
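
Both entry points follow one positional convention: credentials, when present, sit immediately after the URL, and the compression method is the trailing argument exactly when the argument count says so (3 or 5 for the engine, 4 or 6 for the table function below). A condensed sketch of the engine-side dispatch over plain strings; the real registerStorageS3 above reads ASTLiteral values, so the names and types here are illustrative:

    #include <stdexcept>
    #include <string>
    #include <vector>

    struct S3EngineArgs
    {
        std::string url, access_key_id, secret_access_key, format, compression;
    };

    S3EngineArgs parseS3EngineArgs(const std::vector<std::string> & args)
    {
        if (args.size() < 2 || args.size() > 5)
            throw std::invalid_argument("Storage S3 requires 2 to 5 arguments");

        S3EngineArgs res;
        res.url = args[0];

        if (args.size() >= 4)   /// credentials sit right after the URL
        {
            res.access_key_id = args[1];
            res.secret_access_key = args[2];
        }

        /// An odd count (3 or 5) means the last argument is the compression
        /// method; the format then precedes it.
        bool has_compression = args.size() == 3 || args.size() == 5;
        res.compression = has_compression ? args.back() : "auto";
        res.format = args[args.size() - (has_compression ? 2 : 1)];
        return res;
    }
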
diff --git a/dbms/src/Storages/StorageS3.h b/dbms/src/Storages/StorageS3.h
index 88b470ac2ac..4a5288271a2 100644
--- a/dbms/src/Storages/StorageS3.h
+++ b/dbms/src/Storages/StorageS3.h
@@ -18,8 +18,10 @@ class StorageS3 : public ext::shared_ptr_helper<StorageS3>, public IStorage
 public:
     StorageS3(
         const Poco::URI & uri_,
-        const std::string & database_name_,
-        const std::string & table_name_,
+        const String & access_key_id,
+        const String & secret_access_key,
+        const String & database_name_,
+        const String & table_name_,
         const String & format_name_,
         UInt64 min_upload_part_size_,
         const ColumnsDescription & columns_,
@@ -56,6 +58,8 @@ public:

 private:
     Poco::URI uri;
+    String access_key_id;
+    String secret_access_key;
     const Context & context_global;

     String format_name;
diff --git a/dbms/src/TableFunctions/TableFunctionS3.cpp b/dbms/src/TableFunctions/TableFunctionS3.cpp
index a9ee5ebf691..d203801d9c1 100644
--- a/dbms/src/TableFunctions/TableFunctionS3.cpp
+++ b/dbms/src/TableFunctions/TableFunctionS3.cpp
@@ -1,17 +1,84 @@
 #include <Storages/StorageS3.h>
+#include <Interpreters/evaluateConstantExpression.h>
 #include <TableFunctions/TableFunctionFactory.h>
 #include <TableFunctions/TableFunctionS3.h>
+#include <Parsers/ASTLiteral.h>
+#include <TableFunctions/parseColumnsListForTableFunction.h>
 #include <Poco/URI.h>


 namespace DB
 {

+namespace ErrorCodes
+{
+    extern const int LOGICAL_ERROR;
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+
+StoragePtr TableFunctionS3::executeImpl(const ASTPtr & ast_function, const Context & context, const std::string & table_name) const
+{
+    /// Parse args
+    ASTs & args_func = ast_function->children;
+
+    if (args_func.size() != 1)
+        throw Exception("Table function '" + getName() + "' must have arguments.", ErrorCodes::LOGICAL_ERROR);
+
+    ASTs & args = args_func.at(0)->children;
+
+    if (args.size() < 3 || args.size() > 6)
+        throw Exception("Table function '" + getName() + "' requires 3 to 6 arguments: url, [access_key_id, secret_access_key,] format, structure and [compression_method].",
+            ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
+
+    for (size_t i = 0; i < args.size(); ++i)
+        args[i] = evaluateConstantExpressionOrIdentifierAsLiteral(args[i], context);
+
+    String filename = args[0]->as<ASTLiteral &>().value.safeGet<String>();
+    String format;
+    String structure;
+    String access_key_id;
+    String secret_access_key;
+
+    if (args.size() < 5)
+    {
+        format = args[1]->as<ASTLiteral &>().value.safeGet<String>();
+        structure = args[2]->as<ASTLiteral &>().value.safeGet<String>();
+    }
+    else
+    {
+        access_key_id = args[1]->as<ASTLiteral &>().value.safeGet<String>();
+        secret_access_key = args[2]->as<ASTLiteral &>().value.safeGet<String>();
+        format = args[3]->as<ASTLiteral &>().value.safeGet<String>();
+        structure = args[4]->as<ASTLiteral &>().value.safeGet<String>();
+    }
+
+    String compression_method;
+    if (args.size() == 4 || args.size() == 6)
+        compression_method = args.back()->as<ASTLiteral &>().value.safeGet<String>();
+    else
+        compression_method = "auto";
+
+    ColumnsDescription columns = parseColumnsListFromString(structure, context);
+
+    /// Create table
+    StoragePtr storage = getStorage(filename, access_key_id, secret_access_key, format, columns, const_cast<Context &>(context), table_name, compression_method);
+
+    storage->startup();
+
+    return storage;
+}
+
 StoragePtr TableFunctionS3::getStorage(
-    const String & source, const String & format, const ColumnsDescription & columns, Context & global_context, const std::string & table_name, const String & compression_method) const
+    const String & source,
+    const String & access_key_id,
+    const String & secret_access_key,
+    const String & format,
+    const ColumnsDescription & columns,
+    Context & global_context,
+    const std::string & table_name,
+    const String & compression_method) const
 {
     Poco::URI uri(source);
     UInt64 min_upload_part_size = global_context.getSettingsRef().s3_min_upload_part_size;
-    return StorageS3::create(uri, getDatabaseName(), table_name, format, min_upload_part_size, columns, ConstraintsDescription{}, global_context, compression_method);
+    return StorageS3::create(uri, access_key_id, secret_access_key, getDatabaseName(), table_name, format, min_upload_part_size, columns, ConstraintsDescription{}, global_context, compression_method);
 }

 void registerTableFunctionS3(TableFunctionFactory & factory)
diff --git a/dbms/src/TableFunctions/TableFunctionS3.h b/dbms/src/TableFunctions/TableFunctionS3.h
index 2f14e0319d4..0c81e0ed2a7 100644
--- a/dbms/src/TableFunctions/TableFunctionS3.h
+++ b/dbms/src/TableFunctions/TableFunctionS3.h
@@ -1,6 +1,6 @@
 #pragma once

-#include <TableFunctions/ITableFunctionFileLike.h>
+#include <TableFunctions/ITableFunction.h>


 namespace DB
@@ -8,9 +8,9 @@ namespace DB

 class Context;

-/* s3(source, format, structure) - creates a temporary storage for a file in S3
+/* s3(source, [access_key_id, secret_access_key,] format, structure) - creates a temporary storage for a file in S3
  */
-class TableFunctionS3 : public ITableFunctionFileLike
+class TableFunctionS3 : public ITableFunction
 {
 public:
     static constexpr auto name = "s3";
@@ -20,13 +20,20 @@ public:
     }

 private:
+    StoragePtr executeImpl(
+        const ASTPtr & ast_function,
+        const Context & context,
+        const std::string & table_name) const override;
+
     StoragePtr getStorage(
         const String & source,
+        const String & access_key_id,
+        const String & secret_access_key,
         const String & format,
         const ColumnsDescription & columns,
         Context & global_context,
         const std::string & table_name,
-        const String & compression_method) const override;
+        const String & compression_method) const;
 };

 }
diff --git a/dbms/tests/integration/test_storage_s3/test.py b/dbms/tests/integration/test_storage_s3/test.py
index 1db472e3019..ed447274e86 100644
--- a/dbms/tests/integration/test_storage_s3/test.py
+++ b/dbms/tests/integration/test_storage_s3/test.py
@@ -5,6 +5,9 @@ import pytest

 from helpers.cluster import ClickHouseCluster, ClickHouseInstance

+import helpers.client
+
+
 logging.getLogger().setLevel(logging.INFO)
 logging.getLogger().addHandler(logging.StreamHandler())

@@ -53,12 +56,18 @@ def prepare_s3_bucket(cluster):
     minio_client.set_bucket_policy(cluster.minio_bucket, json.dumps(bucket_read_write_policy))

+    cluster.minio_restricted_bucket = "{}-with-auth".format(cluster.minio_bucket)
+    if minio_client.bucket_exists(cluster.minio_restricted_bucket):
+        minio_client.remove_bucket(cluster.minio_restricted_bucket)
+
+    minio_client.make_bucket(cluster.minio_restricted_bucket)
+

 # Returns content of given S3 file as string.
-def get_s3_file_content(cluster, filename):
+def get_s3_file_content(cluster, bucket, filename):
     # type: (ClickHouseCluster, str) -> str
-    data = cluster.minio_client.get_object(cluster.minio_bucket, filename)
+    data = cluster.minio_client.get_object(bucket, filename)
     data_str = ""
     for chunk in data.stream():
         data_str += chunk
@@ -101,53 +110,76 @@ def run_query(instance, query, stdin=None, settings=None):


 # Test simple put.
-def test_put(cluster):
+@pytest.mark.parametrize("maybe_auth,positive", [
+    ("", True),
+    ("'minio','minio123',", True),
+    ("'wrongid','wrongkey',", False)
+])
+def test_put(cluster, maybe_auth, positive):
     # type: (ClickHouseCluster) -> None

+    bucket = cluster.minio_bucket if not maybe_auth else cluster.minio_restricted_bucket
     instance = cluster.instances["dummy"]  # type: ClickHouseInstance
     table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
     values = "(1, 2, 3), (3, 2, 1), (78, 43, 45)"
     values_csv = "1,2,3\n3,2,1\n78,43,45\n"
     filename = "test.csv"
-    put_query = "insert into table function s3('http://{}:{}/{}/{}', 'CSV', '{}') values {}".format(
-        cluster.minio_host, cluster.minio_port, cluster.minio_bucket, filename, table_format, values)
-
-    run_query(instance, put_query)
+    put_query = "insert into table function s3('http://{}:{}/{}/{}', {}'CSV', '{}') values {}".format(
+        cluster.minio_host, cluster.minio_port, bucket, filename, maybe_auth, table_format, values)

-    assert values_csv == get_s3_file_content(cluster, filename)
+    try:
+        run_query(instance, put_query)
+    except helpers.client.QueryRuntimeException:
+        assert not positive
+    else:
+        assert positive
+        assert values_csv == get_s3_file_content(cluster, bucket, filename)


 # Test put values in CSV format.
-def test_put_csv(cluster):
+@pytest.mark.parametrize("maybe_auth,positive", [
+    ("", True),
+    ("'minio','minio123',", True),
+    ("'wrongid','wrongkey',", False)
+])
+def test_put_csv(cluster, maybe_auth, positive):
     # type: (ClickHouseCluster) -> None

+    bucket = cluster.minio_bucket if not maybe_auth else cluster.minio_restricted_bucket
     instance = cluster.instances["dummy"]  # type: ClickHouseInstance
     table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
     filename = "test.csv"
-    put_query = "insert into table function s3('http://{}:{}/{}/{}', 'CSV', '{}') format CSV".format(
-        cluster.minio_host, cluster.minio_port, cluster.minio_bucket, filename, table_format)
+    put_query = "insert into table function s3('http://{}:{}/{}/{}', {}'CSV', '{}') format CSV".format(
+        cluster.minio_host, cluster.minio_port, bucket, filename, maybe_auth, table_format)
     csv_data = "8,9,16\n11,18,13\n22,14,2\n"
-    run_query(instance, put_query, stdin=csv_data)

-    assert csv_data == get_s3_file_content(cluster, filename)
+    try:
+        run_query(instance, put_query, stdin=csv_data)
+    except helpers.client.QueryRuntimeException:
+        assert not positive
+    else:
+        assert positive
+        assert csv_data == get_s3_file_content(cluster, bucket, filename)


 # Test put and get with S3 server redirect.
 def test_put_get_with_redirect(cluster):
     # type: (ClickHouseCluster) -> None

+    bucket = cluster.minio_bucket
     instance = cluster.instances["dummy"]  # type: ClickHouseInstance
     table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
     values = "(1, 1, 1), (1, 1, 1), (11, 11, 11)"
     values_csv = "1,1,1\n1,1,1\n11,11,11\n"
     filename = "test.csv"
     query = "insert into table function s3('http://{}:{}/{}/{}', 'CSV', '{}') values {}".format(
-        cluster.minio_redirect_host, cluster.minio_redirect_port, cluster.minio_bucket, filename, table_format, values)
+        cluster.minio_redirect_host, cluster.minio_redirect_port, bucket, filename, table_format, values)
     run_query(instance, query)

-    assert values_csv == get_s3_file_content(cluster, filename)
+    assert values_csv == get_s3_file_content(cluster, bucket, filename)

     query = "select *, column1*column2*column3 from s3('http://{}:{}/{}/{}', 'CSV', '{}')".format(
-        cluster.minio_redirect_host, cluster.minio_redirect_port, cluster.minio_bucket, filename, table_format)
+        cluster.minio_redirect_host, cluster.minio_redirect_port, bucket, filename, table_format)
     stdout = run_query(instance, query)

     assert list(map(str.split, stdout.splitlines())) == [
@@ -158,9 +190,15 @@ def test_put_get_with_redirect(cluster):


 # Test multipart put.
-def test_multipart_put(cluster):
+@pytest.mark.parametrize("maybe_auth,positive", [
+    ("", True),
+    ("'minio','minio123',", True),
+    ("'wrongid','wrongkey',", False)
+])
+def test_multipart_put(cluster, maybe_auth, positive):
     # type: (ClickHouseCluster) -> None

+    bucket = cluster.minio_bucket if not maybe_auth else cluster.minio_restricted_bucket
     instance = cluster.instances["dummy"]  # type: ClickHouseInstance
     table_format = "column1 UInt32, column2 UInt32, column3 UInt32"

@@ -178,14 +216,19 @@ def test_multipart_put(cluster):
     assert len(csv_data) > min_part_size_bytes

     filename = "test_multipart.csv"
-    put_query = "insert into table function s3('http://{}:{}/{}/{}', 'CSV', '{}') format CSV".format(
-        cluster.minio_redirect_host, cluster.minio_redirect_port, cluster.minio_bucket, filename, table_format)
+    put_query = "insert into table function s3('http://{}:{}/{}/{}', {}'CSV', '{}') format CSV".format(
+        cluster.minio_redirect_host, cluster.minio_redirect_port, bucket, filename, maybe_auth, table_format)

-    run_query(instance, put_query, stdin=csv_data, settings={'s3_min_upload_part_size': min_part_size_bytes})
+    try:
+        run_query(instance, put_query, stdin=csv_data, settings={'s3_min_upload_part_size': min_part_size_bytes})
+    except helpers.client.QueryRuntimeException:
+        assert not positive
+    else:
+        assert positive

-    # Use Nginx access logs to count number of parts uploaded to Minio.
-    nginx_logs = get_nginx_access_logs()
-    uploaded_parts = filter(lambda log_line: log_line.find(filename) >= 0 and log_line.find("PUT") >= 0, nginx_logs)
-    assert uploaded_parts > 1
+        # Use Nginx access logs to count number of parts uploaded to Minio.
+        nginx_logs = get_nginx_access_logs()
+        uploaded_parts = filter(lambda log_line: log_line.find(filename) >= 0 and log_line.find("PUT") >= 0, nginx_logs)
+        assert len(uploaded_parts) > 1

-    assert csv_data == get_s3_file_content(cluster, filename)
+        assert csv_data == get_s3_file_content(cluster, bucket, filename)
diff --git a/dbms/tests/queries/0_stateless/01030_storage_s3_syntax.sql b/dbms/tests/queries/0_stateless/01030_storage_s3_syntax.sql
index 6579984f57d..44cd149dd51 100644
--- a/dbms/tests/queries/0_stateless/01030_storage_s3_syntax.sql
+++ b/dbms/tests/queries/0_stateless/01030_storage_s3_syntax.sql
@@ -2,7 +2,7 @@ drop table if exists test_table_s3_syntax
 ;
 create table test_table_s3_syntax (id UInt32) ENGINE = S3('')
 ; -- { serverError 42 }
-create table test_table_s3_syntax (id UInt32) ENGINE = S3('','','','')
+create table test_table_s3_syntax (id UInt32) ENGINE = S3('','','','','','')
 ; -- { serverError 42 }
 drop table if exists test_table_s3_syntax
 ;