Mirror of https://github.com/ClickHouse/ClickHouse.git, synced 2024-11-22 07:31:57 +00:00
Merge pull request #5596 from excitoon-favorites/table_function_s3
s3 table function and storage
This commit is contained in: commit 2054f80623
@@ -45,7 +45,7 @@ namespace ErrorCodes

namespace
{
    void setTimeouts(Poco::Net::HTTPClientSession & session, const ConnectionTimeouts & timeouts)
    {
#if defined(POCO_CLICKHOUSE_PATCH) || POCO_VERSION >= 0x02000000
        session.setTimeout(timeouts.connection_timeout, timeouts.send_timeout, timeouts.receive_timeout);

@@ -220,20 +220,25 @@ PooledHTTPSessionPtr makePooledHTTPSession(const Poco::URI & uri, const Connecti

std::istream * receiveResponse(
    Poco::Net::HTTPClientSession & session, const Poco::Net::HTTPRequest & request, Poco::Net::HTTPResponse & response)
{
-    auto istr = &session.receiveResponse(response);
+    auto & istr = session.receiveResponse(response);
+    assertResponseIsOk(request, response, istr);
+    return &istr;
+}
+
+void assertResponseIsOk(const Poco::Net::HTTPRequest & request, Poco::Net::HTTPResponse & response, std::istream & istr)
+{
    auto status = response.getStatus();

    if (status != Poco::Net::HTTPResponse::HTTP_OK)
    {
        std::stringstream error_message;
        error_message << "Received error from remote server " << request.getURI() << ". HTTP status code: " << status << " "
-            << response.getReason() << ", body: " << istr->rdbuf();
+            << response.getReason() << ", body: " << istr.rdbuf();

        throw Exception(error_message.str(),
            status == HTTP_TOO_MANY_REQUESTS ? ErrorCodes::RECEIVED_ERROR_TOO_MANY_REQUESTS
                                             : ErrorCodes::RECEIVED_ERROR_FROM_REMOTE_IO_SERVER);
    }
-    return istr;
}

@@ -57,4 +57,6 @@ PooledHTTPSessionPtr makePooledHTTPSession(const Poco::URI & uri, const Connecti
 */
std::istream * receiveResponse(
    Poco::Net::HTTPClientSession & session, const Poco::Net::HTTPRequest & request, Poco::Net::HTTPResponse & response);
+void assertResponseIsOk(const Poco::Net::HTTPRequest & request, Poco::Net::HTTPResponse & response, std::istream & istr);

}
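The change above splits the HTTP status check out of receiveResponse() into a reusable assertResponseIsOk(), which the new S3 buffers below call directly because they receive responses themselves while following 307 redirects. A minimal caller sketch, assuming only the declarations in this header (the helper function itself is hypothetical):

#include <IO/HTTPCommon.h>

#include <Poco/Net/HTTPClientSession.h>
#include <Poco/Net/HTTPRequest.h>
#include <Poco/Net/HTTPResponse.h>
#include <Poco/URI.h>

/// Hypothetical caller: issue a GET on an existing session and validate the
/// response before handing back the body stream.
std::istream & getOkBody(Poco::Net::HTTPClientSession & session, const Poco::URI & uri)
{
    Poco::Net::HTTPRequest request(Poco::Net::HTTPRequest::HTTP_GET, uri.getPathAndQuery());
    session.sendRequest(request);

    Poco::Net::HTTPResponse response;
    std::istream & body = session.receiveResponse(response);
    DB::assertResponseIsOk(request, response, body); /// throws on non-200, embedding the response body in the message
    return body;
}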
70  dbms/src/IO/ReadBufferFromS3.cpp  Normal file
@@ -0,0 +1,70 @@
#include <IO/ReadBufferFromS3.h>

#include <IO/ReadBufferFromIStream.h>

#include <common/logger_useful.h>


namespace DB
{

const int DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT = 2;

ReadBufferFromS3::ReadBufferFromS3(Poco::URI uri_,
    const ConnectionTimeouts & timeouts,
    const Poco::Net::HTTPBasicCredentials & credentials,
    size_t buffer_size_)
    : ReadBuffer(nullptr, 0)
    , uri {uri_}
    , method {Poco::Net::HTTPRequest::HTTP_GET}
    , session {makeHTTPSession(uri_, timeouts)}
{
    Poco::Net::HTTPResponse response;
    std::unique_ptr<Poco::Net::HTTPRequest> request;

    for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT; ++i)
    {
        // With an empty path, Poco will send "POST HTTP/1.1" (its bug).
        if (uri.getPath().empty())
            uri.setPath("/");

        request = std::make_unique<Poco::Net::HTTPRequest>(method, uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
        request->setHost(uri.getHost()); // use the original, not resolved, host name in the header

        if (!credentials.getUsername().empty())
            credentials.authenticate(*request);

        LOG_TRACE((&Logger::get("ReadBufferFromS3")), "Sending request to " << uri.toString());

        session->sendRequest(*request);

        istr = &session->receiveResponse(response);

        // Handle 307 Temporary Redirect in order to allow request redirection.
        // See https://docs.aws.amazon.com/AmazonS3/latest/dev/Redirects.html
        if (response.getStatus() != Poco::Net::HTTPResponse::HTTP_TEMPORARY_REDIRECT)
            break;

        auto location_iterator = response.find("Location");
        if (location_iterator == response.end())
            break;

        uri = location_iterator->second;
        session = makeHTTPSession(uri, timeouts);
    }

    assertResponseIsOk(*request, response, *istr);
    impl = std::make_unique<ReadBufferFromIStream>(*istr, buffer_size_);
}


bool ReadBufferFromS3::nextImpl()
{
    if (!impl->next())
        return false;
    internal_buffer = impl->buffer();
    working_buffer = internal_buffer;
    return true;
}

}
35  dbms/src/IO/ReadBufferFromS3.h  Normal file
@@ -0,0 +1,35 @@
#pragma once

#include <memory>

#include <IO/ConnectionTimeouts.h>
#include <IO/HTTPCommon.h>
#include <IO/ReadBuffer.h>
#include <Poco/Net/HTTPBasicCredentials.h>
#include <Poco/URI.h>


namespace DB
{
/** Perform S3 HTTP GET request and provide response to read.
  */
class ReadBufferFromS3 : public ReadBuffer
{
protected:
    Poco::URI uri;
    std::string method;

    HTTPSessionPtr session;
    std::istream * istr; /// owned by session
    std::unique_ptr<ReadBuffer> impl;

public:
    explicit ReadBufferFromS3(Poco::URI uri_,
        const ConnectionTimeouts & timeouts = {},
        const Poco::Net::HTTPBasicCredentials & credentials = {},
        size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE);

    bool nextImpl() override;
};

}
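A minimal read-side sketch, assuming the header above; the object URL is hypothetical. The constructor performs the GET (following up to two redirects), so the buffer is readable as soon as it is constructed:

#include <IO/ReadBufferFromS3.h>
#include <IO/WriteBufferFromFileDescriptor.h>
#include <IO/copyData.h>

#include <unistd.h>

int main()
{
    /// Hypothetical object URL; any HTTP(S) endpoint serving the object works.
    Poco::URI uri("http://storage.example.com/bucket/data.csv");

    DB::ReadBufferFromS3 in(uri);                          /// sends the GET in the constructor
    DB::WriteBufferFromFileDescriptor out(STDOUT_FILENO);
    DB::copyData(in, out);                                 /// stream the response body to stdout
    out.next();                                            /// flush
    return 0;
}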
286  dbms/src/IO/WriteBufferFromS3.cpp  Normal file
@@ -0,0 +1,286 @@
#include <IO/WriteBufferFromS3.h>

#include <IO/WriteHelpers.h>

#include <Poco/DOM/AutoPtr.h>
#include <Poco/DOM/DOMParser.h>
#include <Poco/DOM/Document.h>
#include <Poco/DOM/NodeList.h>
#include <Poco/SAX/InputSource.h>

#include <common/logger_useful.h>


namespace DB
{

const int DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT = 2;
const int S3_WARN_MAX_PARTS = 10000;


namespace ErrorCodes
{
    extern const int INCORRECT_DATA;
}


WriteBufferFromS3::WriteBufferFromS3(
    const Poco::URI & uri_,
    size_t minimum_upload_part_size_,
    const ConnectionTimeouts & timeouts_,
    const Poco::Net::HTTPBasicCredentials & credentials, size_t buffer_size_
)
    : BufferWithOwnMemory<WriteBuffer>(buffer_size_, nullptr, 0)
    , uri {uri_}
    , minimum_upload_part_size {minimum_upload_part_size_}
    , timeouts {timeouts_}
    , auth_request {Poco::Net::HTTPRequest::HTTP_PUT, uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1}
    , temporary_buffer {std::make_unique<WriteBufferFromString>(buffer_string)}
    , last_part_size {0}
{
    if (!credentials.getUsername().empty())
        credentials.authenticate(auth_request);

    initiate();
}


void WriteBufferFromS3::nextImpl()
{
    if (!offset())
        return;

    temporary_buffer->write(working_buffer.begin(), offset());

    last_part_size += offset();

    if (last_part_size > minimum_upload_part_size)
    {
        temporary_buffer->finish();
        writePart(buffer_string);
        last_part_size = 0;
        temporary_buffer = std::make_unique<WriteBufferFromString>(buffer_string);
    }
}


void WriteBufferFromS3::finalize()
{
    temporary_buffer->finish();
    if (!buffer_string.empty())
    {
        writePart(buffer_string);
    }

    complete();
}


WriteBufferFromS3::~WriteBufferFromS3()
{
    try
    {
        next();
    }
    catch (...)
    {
        tryLogCurrentException(__PRETTY_FUNCTION__);
    }
}


void WriteBufferFromS3::initiate()
{
    // See https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadInitiate.html
    Poco::Net::HTTPResponse response;
    std::unique_ptr<Poco::Net::HTTPRequest> request_ptr;
    HTTPSessionPtr session;
    std::istream * istr = nullptr; /// owned by session
    Poco::URI initiate_uri = uri;
    initiate_uri.setRawQuery("uploads");
    for (auto & param : uri.getQueryParameters())
    {
        initiate_uri.addQueryParameter(param.first, param.second);
    }

    for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT; ++i)
    {
        session = makeHTTPSession(initiate_uri, timeouts);
        request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, initiate_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
        request_ptr->setHost(initiate_uri.getHost()); // use the original, not resolved, host name in the header

        if (auth_request.hasCredentials())
        {
            Poco::Net::HTTPBasicCredentials credentials(auth_request);
            credentials.authenticate(*request_ptr);
        }

        request_ptr->setContentLength(0);

        LOG_TRACE((&Logger::get("WriteBufferFromS3")), "Sending request to " << initiate_uri.toString());

        session->sendRequest(*request_ptr);

        istr = &session->receiveResponse(response);

        // Handle 307 Temporary Redirect in order to allow request redirection.
        // See https://docs.aws.amazon.com/AmazonS3/latest/dev/Redirects.html
        if (response.getStatus() != Poco::Net::HTTPResponse::HTTP_TEMPORARY_REDIRECT)
            break;

        auto location_iterator = response.find("Location");
        if (location_iterator == response.end())
            break;

        initiate_uri = location_iterator->second;
    }
    assertResponseIsOk(*request_ptr, response, *istr);

    Poco::XML::InputSource src(*istr);
    Poco::XML::DOMParser parser;
    Poco::AutoPtr<Poco::XML::Document> document = parser.parse(&src);
    Poco::AutoPtr<Poco::XML::NodeList> nodes = document->getElementsByTagName("UploadId");
    if (nodes->length() != 1)
    {
        throw Exception("Incorrect XML in response, no upload id", ErrorCodes::INCORRECT_DATA);
    }
    upload_id = nodes->item(0)->innerText();
    if (upload_id.empty())
    {
        throw Exception("Incorrect XML in response, empty upload id", ErrorCodes::INCORRECT_DATA);
    }
}


void WriteBufferFromS3::writePart(const String & data)
{
    // See https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPart.html
    Poco::Net::HTTPResponse response;
    std::unique_ptr<Poco::Net::HTTPRequest> request_ptr;
    HTTPSessionPtr session;
    std::istream * istr = nullptr; /// owned by session
    Poco::URI part_uri = uri;
    part_uri.addQueryParameter("partNumber", std::to_string(part_tags.size() + 1));
    part_uri.addQueryParameter("uploadId", upload_id);

    if (part_tags.size() == S3_WARN_MAX_PARTS)
    {
        // Don't throw an exception here ourselves; leave the decision to the S3 server.
        LOG_WARNING(&Logger::get("WriteBufferFromS3"), "The maximum part number in the S3 protocol has been reached (too many parts). The server may not accept this whole upload.");
    }

    for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT; ++i)
    {
        session = makeHTTPSession(part_uri, timeouts);
        request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_PUT, part_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
        request_ptr->setHost(part_uri.getHost()); // use the original, not resolved, host name in the header

        if (auth_request.hasCredentials())
        {
            Poco::Net::HTTPBasicCredentials credentials(auth_request);
            credentials.authenticate(*request_ptr);
        }

        request_ptr->setExpectContinue(true);

        request_ptr->setContentLength(data.size());

        LOG_TRACE((&Logger::get("WriteBufferFromS3")), "Sending request to " << part_uri.toString());

        std::ostream & ostr = session->sendRequest(*request_ptr);
        if (session->peekResponse(response))
        {
            // Received 100-continue.
            ostr << data;
        }

        istr = &session->receiveResponse(response);

        // Handle 307 Temporary Redirect in order to allow request redirection.
        // See https://docs.aws.amazon.com/AmazonS3/latest/dev/Redirects.html
        if (response.getStatus() != Poco::Net::HTTPResponse::HTTP_TEMPORARY_REDIRECT)
            break;

        auto location_iterator = response.find("Location");
        if (location_iterator == response.end())
            break;

        part_uri = location_iterator->second;
    }
    assertResponseIsOk(*request_ptr, response, *istr);

    auto etag_iterator = response.find("ETag");
    if (etag_iterator == response.end())
    {
        throw Exception("Incorrect response, no ETag", ErrorCodes::INCORRECT_DATA);
    }
    part_tags.push_back(etag_iterator->second);
}


void WriteBufferFromS3::complete()
{
    // See https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadComplete.html
    Poco::Net::HTTPResponse response;
    std::unique_ptr<Poco::Net::HTTPRequest> request_ptr;
    HTTPSessionPtr session;
    std::istream * istr = nullptr; /// owned by session
    Poco::URI complete_uri = uri;
    complete_uri.addQueryParameter("uploadId", upload_id);

    String data;
    WriteBufferFromString buffer(data);
    writeString("<CompleteMultipartUpload>", buffer);
    for (size_t i = 0; i < part_tags.size(); ++i)
    {
        writeString("<Part><PartNumber>", buffer);
        writeIntText(i + 1, buffer);
        writeString("</PartNumber><ETag>", buffer);
        writeString(part_tags[i], buffer);
        writeString("</ETag></Part>", buffer);
    }
    writeString("</CompleteMultipartUpload>", buffer);
    buffer.finish();

    for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT; ++i)
    {
        session = makeHTTPSession(complete_uri, timeouts);
        request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, complete_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
        request_ptr->setHost(complete_uri.getHost()); // use the original, not resolved, host name in the header

        if (auth_request.hasCredentials())
        {
            Poco::Net::HTTPBasicCredentials credentials(auth_request);
            credentials.authenticate(*request_ptr);
        }

        request_ptr->setExpectContinue(true);

        request_ptr->setContentLength(data.size());

        LOG_TRACE((&Logger::get("WriteBufferFromS3")), "Sending request to " << complete_uri.toString());

        std::ostream & ostr = session->sendRequest(*request_ptr);
        if (session->peekResponse(response))
        {
            // Received 100-continue.
            ostr << data;
        }

        istr = &session->receiveResponse(response);

        // Handle 307 Temporary Redirect in order to allow request redirection.
        // See https://docs.aws.amazon.com/AmazonS3/latest/dev/Redirects.html
        if (response.getStatus() != Poco::Net::HTTPResponse::HTTP_TEMPORARY_REDIRECT)
            break;

        auto location_iterator = response.find("Location");
        if (location_iterator == response.end())
            break;

        complete_uri = location_iterator->second;
    }
    assertResponseIsOk(*request_ptr, response, *istr);
}

}
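A note on limits: writePart() warns once part_tags.size() reaches S3_WARN_MAX_PARTS (10,000), the part-count ceiling in the S3 multipart protocol. Together with the default part threshold used by the storage below (512,000,000 bytes), that bounds a single upload at roughly 10,000 x 512,000,000 bytes, about 5.12 TB; raising the minimum part size is the lever for larger uploads.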
62  dbms/src/IO/WriteBufferFromS3.h  Normal file
@@ -0,0 +1,62 @@
#pragma once

#include <functional>
#include <memory>
#include <vector>
#include <Core/Types.h>
#include <IO/ConnectionTimeouts.h>
#include <IO/HTTPCommon.h>
#include <IO/BufferWithOwnMemory.h>
#include <IO/ReadBuffer.h>
#include <IO/ReadBufferFromIStream.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteBufferFromString.h>
#include <Poco/Net/HTTPBasicCredentials.h>
#include <Poco/Net/HTTPClientSession.h>
#include <Poco/Net/HTTPRequest.h>
#include <Poco/Net/HTTPResponse.h>
#include <Poco/URI.h>
#include <Poco/Version.h>
#include <Common/DNSResolver.h>
#include <Common/config.h>
#include <common/logger_useful.h>


namespace DB
{
/* Perform S3 HTTP PUT request.
 */
class WriteBufferFromS3 : public BufferWithOwnMemory<WriteBuffer>
{
private:
    Poco::URI uri;
    size_t minimum_upload_part_size;
    ConnectionTimeouts timeouts;
    Poco::Net::HTTPRequest auth_request;
    String buffer_string;
    std::unique_ptr<WriteBufferFromString> temporary_buffer;
    size_t last_part_size;
    String upload_id;
    std::vector<String> part_tags;

public:
    explicit WriteBufferFromS3(const Poco::URI & uri,
        size_t minimum_upload_part_size_,
        const ConnectionTimeouts & timeouts = {},
        const Poco::Net::HTTPBasicCredentials & credentials = {},
        size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE);

    void nextImpl() override;

    /// Receives response from the server after sending all data.
    void finalize();

    ~WriteBufferFromS3() override;

private:
    void initiate();
    void writePart(const String & data);
    void complete();
};

}
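A minimal write-side sketch, assuming the header above; the URL and part size are illustrative. Data is buffered until it exceeds minimum_upload_part_size, each full chunk is uploaded as one part, and finalize() uploads the tail and completes the multipart upload:

#include <IO/WriteBufferFromS3.h>
#include <IO/WriteHelpers.h>

int main()
{
    /// Hypothetical destination URL.
    Poco::URI uri("http://storage.example.com/bucket/out.csv");

    /// 1,000,000-byte part threshold, matching the integration test config below.
    DB::WriteBufferFromS3 out(uri, 1000000);

    for (size_t i = 0; i < 100000; ++i)
        DB::writeString("1,2,3\n", out);    /// buffered; a part is uploaded each time the threshold is crossed

    out.finalize();                         /// upload the last part and POST CompleteMultipartUpload
    return 0;
}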
177  dbms/src/Storages/StorageS3.cpp  Normal file
@@ -0,0 +1,177 @@
#include <Storages/StorageFactory.h>
#include <Storages/StorageS3.h>

#include <Interpreters/Context.h>
#include <Interpreters/evaluateConstantExpression.h>
#include <Parsers/ASTLiteral.h>

#include <IO/ReadBufferFromS3.h>
#include <IO/WriteBufferFromS3.h>

#include <Formats/FormatFactory.h>

#include <DataStreams/IBlockOutputStream.h>
#include <DataStreams/IBlockInputStream.h>
#include <DataStreams/AddingDefaultsBlockInputStream.h>

#include <Poco/Net/HTTPRequest.h>


namespace DB
{
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}

namespace
{
    class StorageS3BlockInputStream : public IBlockInputStream
    {
    public:
        StorageS3BlockInputStream(const Poco::URI & uri,
            const String & format,
            const String & name_,
            const Block & sample_block,
            const Context & context,
            UInt64 max_block_size,
            const ConnectionTimeouts & timeouts)
            : name(name_)
        {
            read_buf = std::make_unique<ReadBufferFromS3>(uri, timeouts);

            reader = FormatFactory::instance().getInput(format, *read_buf, sample_block, context, max_block_size);
        }

        String getName() const override
        {
            return name;
        }

        Block readImpl() override
        {
            return reader->read();
        }

        Block getHeader() const override
        {
            return reader->getHeader();
        }

        void readPrefixImpl() override
        {
            reader->readPrefix();
        }

        void readSuffixImpl() override
        {
            reader->readSuffix();
        }

    private:
        String name;
        std::unique_ptr<ReadBufferFromS3> read_buf;
        BlockInputStreamPtr reader;
    };

    class StorageS3BlockOutputStream : public IBlockOutputStream
    {
    public:
        StorageS3BlockOutputStream(const Poco::URI & uri,
            const String & format,
            const Block & sample_block_,
            const Context & context,
            const ConnectionTimeouts & timeouts)
            : sample_block(sample_block_)
        {
            auto minimum_upload_part_size = context.getConfigRef().getUInt64("s3_minimum_upload_part_size", 512'000'000);
            write_buf = std::make_unique<WriteBufferFromS3>(uri, minimum_upload_part_size, timeouts);
            writer = FormatFactory::instance().getOutput(format, *write_buf, sample_block, context);
        }

        Block getHeader() const override
        {
            return sample_block;
        }

        void write(const Block & block) override
        {
            writer->write(block);
        }

        void writePrefix() override
        {
            writer->writePrefix();
        }

        void writeSuffix() override
        {
            writer->writeSuffix();
            writer->flush();
            write_buf->finalize();
        }

    private:
        Block sample_block;
        std::unique_ptr<WriteBufferFromS3> write_buf;
        BlockOutputStreamPtr writer;
    };
}


BlockInputStreams StorageS3::read(const Names & column_names,
    const SelectQueryInfo & /*query_info*/,
    const Context & context,
    QueryProcessingStage::Enum /*processed_stage*/,
    size_t max_block_size,
    unsigned /*num_streams*/)
{
    BlockInputStreamPtr block_input = std::make_shared<StorageS3BlockInputStream>(uri,
        format_name,
        getName(),
        getHeaderBlock(column_names),
        context,
        max_block_size,
        ConnectionTimeouts::getHTTPTimeouts(context));

    auto column_defaults = getColumns().getDefaults();
    if (column_defaults.empty())
        return {block_input};
    return {std::make_shared<AddingDefaultsBlockInputStream>(block_input, column_defaults, context)};
}

void StorageS3::rename(const String & /*new_path_to_db*/, const String & new_database_name, const String & new_table_name, TableStructureWriteLockHolder &)
{
    table_name = new_table_name;
    database_name = new_database_name;
}

BlockOutputStreamPtr StorageS3::write(const ASTPtr & /*query*/, const Context & /*context*/)
{
    return std::make_shared<StorageS3BlockOutputStream>(
        uri, format_name, getSampleBlock(), context_global, ConnectionTimeouts::getHTTPTimeouts(context_global));
}

void registerStorageS3(StorageFactory & factory)
{
    factory.registerStorage("S3", [](const StorageFactory::Arguments & args)
    {
        ASTs & engine_args = args.engine_args;

        if (engine_args.size() != 2)
            throw Exception(
                "Storage S3 requires exactly 2 arguments: url and name of used format.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        engine_args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[0], args.local_context);

        String url = engine_args[0]->as<ASTLiteral &>().value.safeGet<String>();
        Poco::URI uri(url);

        engine_args[1] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[1], args.local_context);

        String format_name = engine_args[1]->as<ASTLiteral &>().value.safeGet<String>();

        return StorageS3::create(uri, args.database_name, args.table_name, format_name, args.columns, args.context);
    });
}
}
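For reference: the output stream above reads s3_minimum_upload_part_size from the server configuration, defaulting to 512,000,000 bytes; the integration test config below (min_chunk_size.xml) lowers it to 1,000,000 bytes so that even small inserts exercise the multipart path.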
71  dbms/src/Storages/StorageS3.h  Normal file
@@ -0,0 +1,71 @@
#pragma once

#include <Storages/IStorage.h>
#include <Poco/URI.h>
#include <common/logger_useful.h>
#include <ext/shared_ptr_helper.h>


namespace DB
{
/**
 * This class implements the table engine for external S3 URLs.
 * It sends an HTTP GET to the server when SELECT is called
 * and an HTTP PUT when INSERT is called.
 */
class StorageS3 : public ext::shared_ptr_helper<StorageS3>, public IStorage
{
public:
    StorageS3(const Poco::URI & uri_,
        const std::string & database_name_,
        const std::string & table_name_,
        const String & format_name_,
        const ColumnsDescription & columns_,
        Context & context_
    )
        : IStorage(columns_)
        , uri(uri_)
        , context_global(context_)
        , format_name(format_name_)
        , database_name(database_name_)
        , table_name(table_name_)
    {
        setColumns(columns_);
    }

    String getName() const override
    {
        return "S3";
    }

    Block getHeaderBlock(const Names & /*column_names*/) const
    {
        return getSampleBlock();
    }

    String getTableName() const override
    {
        return table_name;
    }

    BlockInputStreams read(const Names & column_names,
        const SelectQueryInfo & query_info,
        const Context & context,
        QueryProcessingStage::Enum processed_stage,
        size_t max_block_size,
        unsigned num_streams) override;

    BlockOutputStreamPtr write(const ASTPtr & query, const Context & context) override;

    void rename(const String & new_path_to_db, const String & new_database_name, const String & new_table_name, TableStructureWriteLockHolder &) override;

protected:
    Poco::URI uri;
    const Context & context_global;

private:
    String format_name;
    String database_name;
    String table_name;
};

}
@@ -19,6 +19,7 @@ void registerStorageDistributed(StorageFactory & factory);
void registerStorageMemory(StorageFactory & factory);
void registerStorageFile(StorageFactory & factory);
void registerStorageURL(StorageFactory & factory);
+void registerStorageS3(StorageFactory & factory);
void registerStorageDictionary(StorageFactory & factory);
void registerStorageSet(StorageFactory & factory);
void registerStorageJoin(StorageFactory & factory);

@@ -60,6 +61,7 @@ void registerStorages()
    registerStorageMemory(factory);
    registerStorageFile(factory);
    registerStorageURL(factory);
+    registerStorageS3(factory);
    registerStorageDictionary(factory);
    registerStorageSet(factory);
    registerStorageJoin(factory);
19  dbms/src/TableFunctions/TableFunctionS3.cpp  Normal file
@@ -0,0 +1,19 @@
#include <Storages/StorageS3.h>
#include <TableFunctions/TableFunctionFactory.h>
#include <TableFunctions/TableFunctionS3.h>
#include <Poco/URI.h>

namespace DB
{
StoragePtr TableFunctionS3::getStorage(
    const String & source, const String & format, const ColumnsDescription & columns, Context & global_context, const std::string & table_name) const
{
    Poco::URI uri(source);
    return StorageS3::create(uri, getDatabaseName(), table_name, format, columns, global_context);
}

void registerTableFunctionS3(TableFunctionFactory & factory)
{
    factory.registerFunction<TableFunctionS3>();
}
}
25  dbms/src/TableFunctions/TableFunctionS3.h  Normal file
@@ -0,0 +1,25 @@
#pragma once

#include <TableFunctions/ITableFunctionFileLike.h>
#include <Interpreters/Context.h>
#include <Core/Block.h>


namespace DB
{
/* s3(source, format, structure) - creates a temporary storage for a file in S3.
 */
class TableFunctionS3 : public ITableFunctionFileLike
{
public:
    static constexpr auto name = "s3";
    std::string getName() const override
    {
        return name;
    }

private:
    StoragePtr getStorage(
        const String & source, const String & format, const ColumnsDescription & columns, Context & global_context, const std::string & table_name) const override;
};
}
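Since TableFunctionS3 inherits the (source, format, structure) argument parsing from ITableFunctionFileLike, a query can address an S3 URL directly, e.g. select * from s3('http://host:port/bucket/test.csv', 'CSV', 'column1 UInt32, column2 UInt32, column3 UInt32'), as the integration tests below do; the function builds a temporary StorageS3 over that URL for the lifetime of the query.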
@@ -11,6 +11,7 @@ void registerTableFunctionMerge(TableFunctionFactory & factory);
void registerTableFunctionRemote(TableFunctionFactory & factory);
void registerTableFunctionNumbers(TableFunctionFactory & factory);
void registerTableFunctionFile(TableFunctionFactory & factory);
+void registerTableFunctionS3(TableFunctionFactory & factory);
void registerTableFunctionURL(TableFunctionFactory & factory);
void registerTableFunctionValues(TableFunctionFactory & factory);
void registerTableFunctionInput(TableFunctionFactory & factory);

@@ -38,6 +39,7 @@ void registerTableFunctions()
    registerTableFunctionRemote(factory);
    registerTableFunctionNumbers(factory);
    registerTableFunctionFile(factory);
+    registerTableFunctionS3(factory);
    registerTableFunctionURL(factory);
    registerTableFunctionValues(factory);
    registerTableFunctionInput(factory);
@@ -225,12 +225,12 @@ class ClickHouseCluster:
    def restart_instance_with_ip_change(self, node, new_ip):
        if '::' in new_ip:
            if node.ipv6_address is None:
-                raise Exception("You shoud specity ipv6_address in add_node method")
+                raise Exception("You should specify ipv6_address in add_node method")
            self._replace(node.docker_compose_path, node.ipv6_address, new_ip)
            node.ipv6_address = new_ip
        else:
            if node.ipv4_address is None:
-                raise Exception("You shoud specity ipv4_address in add_node method")
+                raise Exception("You should specify ipv4_address in add_node method")
            self._replace(node.docker_compose_path, node.ipv4_address, new_ip)
            node.ipv4_address = new_ip
        subprocess.check_call(self.base_cmd + ["stop", node.name])
@@ -107,4 +107,4 @@ if __name__ == "__main__":
    )

    #print(cmd)
-    subprocess.check_call(cmd, shell=True)
+    subprocess.check_call(cmd, shell=True)
0  dbms/tests/integration/test_storage_s3/__init__.py  Normal file
3  dbms/tests/integration/test_storage_s3/configs/min_chunk_size.xml  Normal file
@@ -0,0 +1,3 @@
<yandex>
    <s3_minimum_upload_part_size>1000000</s3_minimum_upload_part_size>
</yandex>
159  dbms/tests/integration/test_storage_s3/test.py  Normal file
@@ -0,0 +1,159 @@
import httplib
import json
import logging
import os
import time
import traceback

import pytest

from helpers.cluster import ClickHouseCluster


logging.getLogger().setLevel(logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler())


def get_communication_data(started_cluster):
    conn = httplib.HTTPConnection(started_cluster.instances["dummy"].ip_address, started_cluster.communication_port)
    conn.request("GET", "/")
    r = conn.getresponse()
    raw_data = r.read()
    conn.close()
    return json.loads(raw_data)


def put_communication_data(started_cluster, body):
    conn = httplib.HTTPConnection(started_cluster.instances["dummy"].ip_address, started_cluster.communication_port)
    conn.request("PUT", "/", body)
    r = conn.getresponse()
    conn.close()


@pytest.fixture(scope="module")
def started_cluster():
    try:
        cluster = ClickHouseCluster(__file__)
        instance = cluster.add_instance("dummy", config_dir="configs", main_configs=["configs/min_chunk_size.xml"])
        cluster.start()

        cluster.communication_port = 10000
        instance.copy_file_to_container(os.path.join(os.path.dirname(__file__), "test_server.py"), "test_server.py")
        cluster.bucket = "abc"
        instance.exec_in_container(["python", "test_server.py", str(cluster.communication_port), cluster.bucket], detach=True)
        cluster.mock_host = instance.ip_address

        for i in range(10):
            try:
                data = get_communication_data(cluster)
                cluster.redirecting_to_http_port = data["redirecting_to_http_port"]
                cluster.preserving_data_port = data["preserving_data_port"]
                cluster.multipart_preserving_data_port = data["multipart_preserving_data_port"]
                cluster.redirecting_preserving_data_port = data["redirecting_preserving_data_port"]
            except:
                logging.error(traceback.format_exc())
                time.sleep(0.5)
            else:
                break
        else:
            assert False, "Could not initialize mock server"

        yield cluster

    finally:
        cluster.shutdown()


def run_query(instance, query, stdin=None):
    logging.info("Running query '{}'...".format(query))
    result = instance.query(query, stdin=stdin)
    logging.info("Query finished")
    return result


def test_get_with_redirect(started_cluster):
    instance = started_cluster.instances["dummy"]
    format = "column1 UInt32, column2 UInt32, column3 UInt32"

    put_communication_data(started_cluster, "=== Get with redirect test ===")
    query = "select *, column1*column2*column3 from s3('http://{}:{}/', 'CSV', '{}')".format(started_cluster.mock_host, started_cluster.redirecting_to_http_port, format)
    stdout = run_query(instance, query)
    data = get_communication_data(started_cluster)
    expected = [ [str(row[0]), str(row[1]), str(row[2]), str(row[0]*row[1]*row[2])] for row in data["redirect_csv_data"] ]
    assert list(map(str.split, stdout.splitlines())) == expected


def test_put(started_cluster):
    instance = started_cluster.instances["dummy"]
    format = "column1 UInt32, column2 UInt32, column3 UInt32"

    logging.info("Phase 3")
    put_communication_data(started_cluster, "=== Put test ===")
    values = "(1, 2, 3), (3, 2, 1), (78, 43, 45)"
    put_query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') values {}".format(started_cluster.mock_host, started_cluster.preserving_data_port, started_cluster.bucket, format, values)
    run_query(instance, put_query)
    data = get_communication_data(started_cluster)
    received_data_completed = data["received_data_completed"]
    received_data = data["received_data"]
    finalize_data = data["finalize_data"]
    finalize_data_query = data["finalize_data_query"]
    assert received_data[-1].decode() == "1,2,3\n3,2,1\n78,43,45\n"
    assert received_data_completed
    assert finalize_data == "<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>"
    assert finalize_data_query == "uploadId=TEST"


def test_put_csv(started_cluster):
    instance = started_cluster.instances["dummy"]
    format = "column1 UInt32, column2 UInt32, column3 UInt32"

    put_communication_data(started_cluster, "=== Put test CSV ===")
    put_query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') format CSV".format(started_cluster.mock_host, started_cluster.preserving_data_port, started_cluster.bucket, format)
    csv_data = "8,9,16\n11,18,13\n22,14,2\n"
    run_query(instance, put_query, stdin=csv_data)
    data = get_communication_data(started_cluster)
    received_data_completed = data["received_data_completed"]
    received_data = data["received_data"]
    finalize_data = data["finalize_data"]
    finalize_data_query = data["finalize_data_query"]
    assert received_data[-1].decode() == csv_data
    assert received_data_completed
    assert finalize_data == "<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>"
    assert finalize_data_query == "uploadId=TEST"


def test_put_with_redirect(started_cluster):
    instance = started_cluster.instances["dummy"]
    format = "column1 UInt32, column2 UInt32, column3 UInt32"

    put_communication_data(started_cluster, "=== Put with redirect test ===")
    other_values = "(1, 1, 1), (1, 1, 1), (11, 11, 11)"
    query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') values {}".format(started_cluster.mock_host, started_cluster.redirecting_preserving_data_port, started_cluster.bucket, format, other_values)
    run_query(instance, query)

    query = "select *, column1*column2*column3 from s3('http://{}:{}/{}/test.csv', 'CSV', '{}')".format(started_cluster.mock_host, started_cluster.preserving_data_port, started_cluster.bucket, format)
    stdout = run_query(instance, query)
    assert list(map(str.split, stdout.splitlines())) == [
        ["1", "1", "1", "1"],
        ["1", "1", "1", "1"],
        ["11", "11", "11", "1331"],
    ]
    data = get_communication_data(started_cluster)
    received_data = data["received_data"]
    assert received_data[-1].decode() == "1,1,1\n1,1,1\n11,11,11\n"


def test_multipart_put(started_cluster):
    instance = started_cluster.instances["dummy"]
    format = "column1 UInt32, column2 UInt32, column3 UInt32"

    put_communication_data(started_cluster, "=== Multipart test ===")
    long_data = [[i, i+1, i+2] for i in range(100000)]
    long_values = "".join([ "{},{},{}\n".format(x, y, z) for x, y, z in long_data ])
    put_query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') format CSV".format(started_cluster.mock_host, started_cluster.multipart_preserving_data_port, started_cluster.bucket, format)
    run_query(instance, put_query, stdin=long_values)
    data = get_communication_data(started_cluster)
    assert "multipart_received_data" in data
    received_data = data["multipart_received_data"]
    assert received_data[-1].decode() == "".join([ "{},{},{}\n".format(x, y, z) for x, y, z in long_data ])
    assert 1 < data["multipart_parts"] < 10000
365  dbms/tests/integration/test_storage_s3/test_server.py  Normal file
@@ -0,0 +1,365 @@
try:
    from BaseHTTPServer import BaseHTTPRequestHandler
except ImportError:
    from http.server import BaseHTTPRequestHandler

try:
    from BaseHTTPServer import HTTPServer
except ImportError:
    from http.server import HTTPServer

try:
    import urllib.parse as urlparse
except ImportError:
    import urlparse

import json
import logging
import os
import socket
import sys
import threading
import time
import uuid
import xml.etree.ElementTree


logging.getLogger().setLevel(logging.INFO)
file_handler = logging.FileHandler("/var/log/clickhouse-server/test-server.log", "a", encoding="utf-8")
file_handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
logging.getLogger().addHandler(file_handler)
logging.getLogger().addHandler(logging.StreamHandler())

communication_port = int(sys.argv[1])
bucket = sys.argv[2]


def GetFreeTCPPortsAndIP(n):
    result = []
    sockets = []
    for i in range(n):
        tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        tcp.bind((socket.gethostname(), 0))
        addr, port = tcp.getsockname()
        result.append(port)
        sockets.append(tcp)
    [ s.close() for s in sockets ]
    return result, addr


(
    redirecting_to_http_port,
    simple_server_port,
    preserving_data_port,
    multipart_preserving_data_port,
    redirecting_preserving_data_port
), localhost = GetFreeTCPPortsAndIP(5)

data = {
    "redirecting_to_http_port": redirecting_to_http_port,
    "preserving_data_port": preserving_data_port,
    "multipart_preserving_data_port": multipart_preserving_data_port,
    "redirecting_preserving_data_port": redirecting_preserving_data_port,
}


class SimpleHTTPServerHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        logging.info("GET {}".format(self.path))
        if self.path == "/milovidov/test.csv":
            self.send_response(200)
            self.send_header("Content-type", "text/plain")
            self.end_headers()
            data["redirect_csv_data"] = [[42, 87, 44], [55, 33, 81], [1, 0, 9]]
            self.wfile.write("".join([ "{},{},{}\n".format(*row) for row in data["redirect_csv_data"] ]))
        else:
            self.send_response(404)
            self.end_headers()
        self.finish()


class RedirectingToHTTPHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        self.send_response(307)
        self.send_header("Content-type", "text/xml")
        self.send_header("Location", "http://{}:{}/milovidov/test.csv".format(localhost, simple_server_port))
        self.end_headers()
        self.wfile.write(r"""<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>TemporaryRedirect</Code>
<Message>Please re-send this request to the specified temporary endpoint.
Continue to use the original request endpoint for future requests.</Message>
<Endpoint>storage.yandexcloud.net</Endpoint>
</Error>""".encode())
        self.finish()


class PreservingDataHandler(BaseHTTPRequestHandler):
    protocol_version = "HTTP/1.1"

    def parse_request(self):
        result = BaseHTTPRequestHandler.parse_request(self)
        # Adaptation to Python 3.
        if sys.version_info.major == 2 and result == True:
            expect = self.headers.get("Expect", "")
            if (expect.lower() == "100-continue" and self.protocol_version >= "HTTP/1.1" and self.request_version >= "HTTP/1.1"):
                if not self.handle_expect_100():
                    return False
        return result

    def send_response_only(self, code, message=None):
        if message is None:
            if code in self.responses:
                message = self.responses[code][0]
            else:
                message = ""
        if self.request_version != "HTTP/0.9":
            self.wfile.write("%s %d %s\r\n" % (self.protocol_version, code, message))

    def handle_expect_100(self):
        logging.info("Received Expect-100")
        self.send_response_only(100)
        self.end_headers()
        return True

    def do_POST(self):
        self.send_response(200)
        query = urlparse.urlparse(self.path).query
        logging.info("PreservingDataHandler POST ?" + query)
        if query == "uploads":
            post_data = r"""<?xml version="1.0" encoding="UTF-8"?>
<hi><UploadId>TEST</UploadId></hi>""".encode()
            self.send_header("Content-length", str(len(post_data)))
            self.send_header("Content-type", "text/plain")
            self.end_headers()
            self.wfile.write(post_data)
        else:
            post_data = self.rfile.read(int(self.headers.get("Content-Length")))
            self.send_header("Content-type", "text/plain")
            self.end_headers()
            data["received_data_completed"] = True
            data["finalize_data"] = post_data
            data["finalize_data_query"] = query
        self.finish()

    def do_PUT(self):
        self.send_response(200)
        self.send_header("Content-type", "text/plain")
        self.send_header("ETag", "hello-etag")
        self.end_headers()
        query = urlparse.urlparse(self.path).query
        path = urlparse.urlparse(self.path).path
        logging.info("Content-Length = " + self.headers.get("Content-Length"))
        logging.info("PUT " + query)
        assert self.headers.get("Content-Length")
        assert self.headers["Expect"] == "100-continue"
        put_data = self.rfile.read()
        data.setdefault("received_data", []).append(put_data)
        logging.info("PUT to {}".format(path))
        self.server.storage[path] = put_data
        self.finish()

    def do_GET(self):
        path = urlparse.urlparse(self.path).path
        if path in self.server.storage:
            self.send_response(200)
            self.send_header("Content-type", "text/plain")
            self.send_header("Content-length", str(len(self.server.storage[path])))
            self.end_headers()
            self.wfile.write(self.server.storage[path])
        else:
            self.send_response(404)
            self.end_headers()
        self.finish()


class MultipartPreservingDataHandler(BaseHTTPRequestHandler):
    protocol_version = "HTTP/1.1"

    def parse_request(self):
        result = BaseHTTPRequestHandler.parse_request(self)
        # Adaptation to Python 3.
        if sys.version_info.major == 2 and result == True:
            expect = self.headers.get("Expect", "")
            if (expect.lower() == "100-continue" and self.protocol_version >= "HTTP/1.1" and self.request_version >= "HTTP/1.1"):
                if not self.handle_expect_100():
                    return False
        return result

    def send_response_only(self, code, message=None):
        if message is None:
            if code in self.responses:
                message = self.responses[code][0]
            else:
                message = ""
        if self.request_version != "HTTP/0.9":
            self.wfile.write("%s %d %s\r\n" % (self.protocol_version, code, message))

    def handle_expect_100(self):
        logging.info("Received Expect-100")
        self.send_response_only(100)
        self.end_headers()
        return True

    def do_POST(self):
        query = urlparse.urlparse(self.path).query
        logging.info("MultipartPreservingDataHandler POST ?" + query)
        if query == "uploads":
            self.send_response(200)
            post_data = r"""<?xml version="1.0" encoding="UTF-8"?>
<hi><UploadId>TEST</UploadId></hi>""".encode()
            self.send_header("Content-length", str(len(post_data)))
            self.send_header("Content-type", "text/plain")
            self.end_headers()
            self.wfile.write(post_data)
        else:
            try:
                assert query == "uploadId=TEST"
                logging.info("Content-Length = " + self.headers.get("Content-Length"))
                post_data = self.rfile.read(int(self.headers.get("Content-Length")))
                root = xml.etree.ElementTree.fromstring(post_data)
                assert root.tag == "CompleteMultipartUpload"
                assert len(root) > 1
                content = ""
                for i, part in enumerate(root):
                    assert part.tag == "Part"
                    assert len(part) == 2
                    assert part[0].tag == "PartNumber"
                    assert part[1].tag == "ETag"
                    assert int(part[0].text) == i + 1
                    content += self.server.storage["@" + part[1].text]
                data.setdefault("multipart_received_data", []).append(content)
                data["multipart_parts"] = len(root)
                self.send_response(200)
                self.send_header("Content-type", "text/plain")
                self.end_headers()
                logging.info("Sending 200")
            except:
                logging.error("Sending 500")
                self.send_response(500)
        self.finish()

    def do_PUT(self):
        uid = uuid.uuid4()
        self.send_response(200)
        self.send_header("Content-type", "text/plain")
        self.send_header("ETag", str(uid))
        self.end_headers()
        query = urlparse.urlparse(self.path).query
        path = urlparse.urlparse(self.path).path
        logging.info("Content-Length = " + self.headers.get("Content-Length"))
        logging.info("PUT " + query)
        assert self.headers.get("Content-Length")
        assert self.headers["Expect"] == "100-continue"
        put_data = self.rfile.read()
        data.setdefault("received_data", []).append(put_data)
        logging.info("PUT to {}".format(path))
        self.server.storage["@" + str(uid)] = put_data
        self.finish()

    def do_GET(self):
        path = urlparse.urlparse(self.path).path
        if path in self.server.storage:
            self.send_response(200)
            self.send_header("Content-type", "text/plain")
            self.send_header("Content-length", str(len(self.server.storage[path])))
            self.end_headers()
            self.wfile.write(self.server.storage[path])
        else:
            self.send_response(404)
            self.end_headers()
        self.finish()


class RedirectingPreservingDataHandler(BaseHTTPRequestHandler):
    protocol_version = "HTTP/1.1"

    def parse_request(self):
        result = BaseHTTPRequestHandler.parse_request(self)
        # Adaptation to Python 3.
        if sys.version_info.major == 2 and result == True:
            expect = self.headers.get("Expect", "")
            if (expect.lower() == "100-continue" and self.protocol_version >= "HTTP/1.1" and self.request_version >= "HTTP/1.1"):
                if not self.handle_expect_100():
                    return False
        return result

    def send_response_only(self, code, message=None):
        if message is None:
            if code in self.responses:
                message = self.responses[code][0]
            else:
                message = ""
        if self.request_version != "HTTP/0.9":
            self.wfile.write("%s %d %s\r\n" % (self.protocol_version, code, message))

    def handle_expect_100(self):
        logging.info("Received Expect-100")
        return True

    def do_POST(self):
        query = urlparse.urlparse(self.path).query
        if query:
            query = "?{}".format(query)
        self.send_response(307)
        self.send_header("Content-type", "text/xml")
        self.send_header("Location", "http://{host}:{port}/{bucket}/test.csv{query}".format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
        self.end_headers()
        self.wfile.write(r"""<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>TemporaryRedirect</Code>
<Message>Please re-send this request to the specified temporary endpoint.
Continue to use the original request endpoint for future requests.</Message>
<Endpoint>{host}:{port}</Endpoint>
</Error>""".format(host=localhost, port=preserving_data_port).encode())
        self.finish()

    def do_PUT(self):
        query = urlparse.urlparse(self.path).query
        if query:
            query = "?{}".format(query)
        self.send_response(307)
        self.send_header("Content-type", "text/xml")
        self.send_header("Location", "http://{host}:{port}/{bucket}/test.csv{query}".format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
        self.end_headers()
        self.wfile.write(r"""<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>TemporaryRedirect</Code>
<Message>Please re-send this request to the specified temporary endpoint.
Continue to use the original request endpoint for future requests.</Message>
<Endpoint>{host}:{port}</Endpoint>
</Error>""".format(host=localhost, port=preserving_data_port).encode())
        self.finish()


class CommunicationServerHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        self.send_response(200)
        self.end_headers()
        self.wfile.write(json.dumps(data))
        self.finish()

    def do_PUT(self):
        self.send_response(200)
        self.end_headers()
        logging.info(self.rfile.read())
        self.finish()


servers = []
servers.append(HTTPServer((localhost, communication_port), CommunicationServerHandler))
servers.append(HTTPServer((localhost, redirecting_to_http_port), RedirectingToHTTPHandler))
servers.append(HTTPServer((localhost, preserving_data_port), PreservingDataHandler))
servers[-1].storage = {}
servers.append(HTTPServer((localhost, multipart_preserving_data_port), MultipartPreservingDataHandler))
servers[-1].storage = {}
servers.append(HTTPServer((localhost, simple_server_port), SimpleHTTPServerHandler))
servers.append(HTTPServer((localhost, redirecting_preserving_data_port), RedirectingPreservingDataHandler))
jobs = [ threading.Thread(target=server.serve_forever) for server in servers ]
[ job.start() for job in jobs ]

time.sleep(60)  # Timeout

logging.info("Shutting down")
[ server.shutdown() for server in servers ]
logging.info("Joining threads")
[ job.join() for job in jobs ]
logging.info("Done")