diff --git a/dbms/src/IO/ReadBufferFromS3.cpp b/dbms/src/IO/ReadBufferFromS3.cpp
index 7fcb7a0ca41..ae09f0fb189 100644
--- a/dbms/src/IO/ReadBufferFromS3.cpp
+++ b/dbms/src/IO/ReadBufferFromS3.cpp
@@ -5,11 +5,11 @@
 #include
 
-#define DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT 2
-
 namespace DB
 {
 
+const int DEFAULT_S3_MAX_FOLLOW_GET_REDIRECT = 2;
+
 ReadBufferFromS3::ReadBufferFromS3(Poco::URI uri_,
     const ConnectionTimeouts & timeouts,
     const Poco::Net::HTTPBasicCredentials & credentials,
diff --git a/dbms/src/IO/WriteBufferFromS3.cpp b/dbms/src/IO/WriteBufferFromS3.cpp
index 5b6f9fdff4c..1ef6f3b19a0 100644
--- a/dbms/src/IO/WriteBufferFromS3.cpp
+++ b/dbms/src/IO/WriteBufferFromS3.cpp
@@ -11,12 +11,13 @@
 #include
 
-#define DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT 2
-#define S3_SOFT_MAX_PARTS 10000
-
 namespace DB
 {
 
+const int DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT = 2;
+const int S3_WARN_MAX_PARTS = 10000;
+
+
 namespace ErrorCodes
 {
     extern const int INCORRECT_DATA;
@@ -92,34 +93,33 @@ void WriteBufferFromS3::initiate()
 {
     // See https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadInitiate.html
     Poco::Net::HTTPResponse response;
-    std::unique_ptr<Poco::Net::HTTPRequest> request;
+    std::unique_ptr<Poco::Net::HTTPRequest> request_ptr;
     HTTPSessionPtr session;
     std::istream * istr = nullptr; /// owned by session
     Poco::URI initiate_uri = uri;
     initiate_uri.setRawQuery("uploads");
-    auto params = uri.getQueryParameters();
-    for (auto it = params.begin(); it != params.end(); ++it)
+    for (auto & param : uri.getQueryParameters())
     {
-        initiate_uri.addQueryParameter(it->first, it->second);
+        initiate_uri.addQueryParameter(param.first, param.second);
     }
 
     for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT; ++i)
     {
         session = makeHTTPSession(initiate_uri, timeouts);
-        request = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, initiate_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
-        request->setHost(initiate_uri.getHost()); // use original, not resolved host name in header
+        request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, initiate_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
+        request_ptr->setHost(initiate_uri.getHost()); // use original, not resolved host name in header
 
         if (auth_request.hasCredentials())
         {
             Poco::Net::HTTPBasicCredentials credentials(auth_request);
-            credentials.authenticate(*request);
+            credentials.authenticate(*request_ptr);
         }
 
-        request->setContentLength(0);
+        request_ptr->setContentLength(0);
 
         LOG_TRACE((&Logger::get("WriteBufferFromS3")), "Sending request to " << initiate_uri.toString());
 
-        session->sendRequest(*request);
+        session->sendRequest(*request_ptr);
 
         istr = &session->receiveResponse(response);
@@ -134,7 +134,7 @@ void WriteBufferFromS3::initiate()
         initiate_uri = location_iterator->second;
     }
 
-    assertResponseIsOk(*request, response, *istr);
+    assertResponseIsOk(*request_ptr, response, *istr);
 
     Poco::XML::InputSource src(*istr);
     Poco::XML::DOMParser parser;
@@ -156,37 +156,38 @@ void WriteBufferFromS3::writePart(const String & data)
 {
     // See https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadUploadPart.html
     Poco::Net::HTTPResponse response;
-    std::unique_ptr<Poco::Net::HTTPRequest> request;
+    std::unique_ptr<Poco::Net::HTTPRequest> request_ptr;
     HTTPSessionPtr session;
     std::istream * istr = nullptr; /// owned by session
     Poco::URI part_uri = uri;
     part_uri.addQueryParameter("partNumber", std::to_string(part_tags.size() + 1));
     part_uri.addQueryParameter("uploadId", upload_id);
 
-    if (part_tags.size() == S3_SOFT_MAX_PARTS)
+    if (part_tags.size() == S3_WARN_MAX_PARTS)
     {
+        // Don't throw an exception here ourselves; leave that decision to the S3 server.
         LOG_WARNING(&Logger::get("WriteBufferFromS3"), "Maximum part number in S3 protocol has been reached (too many parts). Server may not accept this whole upload.");
     }
 
     for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT; ++i)
     {
         session = makeHTTPSession(part_uri, timeouts);
-        request = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_PUT, part_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
-        request->setHost(part_uri.getHost()); // use original, not resolved host name in header
+        request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_PUT, part_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
+        request_ptr->setHost(part_uri.getHost()); // use original, not resolved host name in header
 
         if (auth_request.hasCredentials())
         {
             Poco::Net::HTTPBasicCredentials credentials(auth_request);
-            credentials.authenticate(*request);
+            credentials.authenticate(*request_ptr);
         }
 
-        request->setExpectContinue(true);
+        request_ptr->setExpectContinue(true);
 
-        request->setContentLength(data.size());
+        request_ptr->setContentLength(data.size());
 
         LOG_TRACE((&Logger::get("WriteBufferFromS3")), "Sending request to " << part_uri.toString());
 
-        std::ostream & ostr = session->sendRequest(*request);
+        std::ostream & ostr = session->sendRequest(*request_ptr);
         if (session->peekResponse(response))
         {
             // Received 100-continue.
@@ -206,7 +207,7 @@ void WriteBufferFromS3::writePart(const String & data)
         part_uri = location_iterator->second;
     }
 
-    assertResponseIsOk(*request, response, *istr);
+    assertResponseIsOk(*request_ptr, response, *istr);
 
     auto etag_iterator = response.find("ETag");
     if (etag_iterator == response.end())
@@ -221,7 +222,7 @@ void WriteBufferFromS3::complete()
 {
     // See https://docs.aws.amazon.com/AmazonS3/latest/API/mpUploadComplete.html
     Poco::Net::HTTPResponse response;
-    std::unique_ptr<Poco::Net::HTTPRequest> request;
+    std::unique_ptr<Poco::Net::HTTPRequest> request_ptr;
     HTTPSessionPtr session;
     std::istream * istr = nullptr; /// owned by session
     Poco::URI complete_uri = uri;
@@ -244,22 +245,22 @@ void WriteBufferFromS3::complete()
     for (int i = 0; i < DEFAULT_S3_MAX_FOLLOW_PUT_REDIRECT; ++i)
     {
         session = makeHTTPSession(complete_uri, timeouts);
-        request = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, complete_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
-        request->setHost(complete_uri.getHost()); // use original, not resolved host name in header
+        request_ptr = std::make_unique<Poco::Net::HTTPRequest>(Poco::Net::HTTPRequest::HTTP_POST, complete_uri.getPathAndQuery(), Poco::Net::HTTPRequest::HTTP_1_1);
+        request_ptr->setHost(complete_uri.getHost()); // use original, not resolved host name in header
 
         if (auth_request.hasCredentials())
         {
             Poco::Net::HTTPBasicCredentials credentials(auth_request);
-            credentials.authenticate(*request);
+            credentials.authenticate(*request_ptr);
         }
 
-        request->setExpectContinue(true);
+        request_ptr->setExpectContinue(true);
 
-        request->setContentLength(data.size());
+        request_ptr->setContentLength(data.size());
 
         LOG_TRACE((&Logger::get("WriteBufferFromS3")), "Sending request to " << complete_uri.toString());
 
-        std::ostream & ostr = session->sendRequest(*request);
+        std::ostream & ostr = session->sendRequest(*request_ptr);
         if (session->peekResponse(response))
        {
             // Received 100-continue.
@@ -279,7 +280,7 @@ void WriteBufferFromS3::complete()
         complete_uri = location_iterator->second;
     }
 
-    assertResponseIsOk(*request, response, *istr);
+    assertResponseIsOk(*request_ptr, response, *istr);
 }
 
 }
diff --git a/dbms/tests/integration/test_storage_s3/test.py b/dbms/tests/integration/test_storage_s3/test.py
index 2013daa6ae6..88be4640388 100644
--- a/dbms/tests/integration/test_storage_s3/test.py
+++ b/dbms/tests/integration/test_storage_s3/test.py
@@ -15,7 +15,7 @@ logging.getLogger().addHandler(logging.StreamHandler())
 
 
 def get_communication_data(started_cluster):
-    conn = httplib.HTTPConnection(started_cluster.instances['dummy'].ip_address, started_cluster.communication_port)
+    conn = httplib.HTTPConnection(started_cluster.instances["dummy"].ip_address, started_cluster.communication_port)
     conn.request("GET", "/")
     r = conn.getresponse()
     raw_data = r.read()
@@ -24,7 +24,7 @@ def get_communication_data(started_cluster):
 
 
 def put_communication_data(started_cluster, body):
-    conn = httplib.HTTPConnection(started_cluster.instances['dummy'].ip_address, started_cluster.communication_port)
+    conn = httplib.HTTPConnection(started_cluster.instances["dummy"].ip_address, started_cluster.communication_port)
     conn.request("PUT", "/", body)
     r = conn.getresponse()
     conn.close()
@@ -34,29 +34,29 @@ def put_communication_data(started_cluster, body):
 def started_cluster():
     try:
         cluster = ClickHouseCluster(__file__)
-        instance = cluster.add_instance('dummy', config_dir="configs", main_configs=['configs/min_chunk_size.xml'])
+        instance = cluster.add_instance("dummy", config_dir="configs", main_configs=["configs/min_chunk_size.xml"])
         cluster.start()
 
         cluster.communication_port = 10000
-        instance.copy_file_to_container(os.path.join(os.path.dirname(__file__), 'test_server.py'), 'test_server.py')
-        cluster.bucket = 'abc'
-        instance.exec_in_container(['python', 'test_server.py', str(cluster.communication_port), cluster.bucket], detach=True)
+        instance.copy_file_to_container(os.path.join(os.path.dirname(__file__), "test_server.py"), "test_server.py")
+        cluster.bucket = "abc"
+        instance.exec_in_container(["python", "test_server.py", str(cluster.communication_port), cluster.bucket], detach=True)
         cluster.mock_host = instance.ip_address
 
         for i in range(10):
             try:
                 data = get_communication_data(cluster)
-                cluster.redirecting_to_http_port = data['redirecting_to_http_port']
-                cluster.preserving_data_port = data['preserving_data_port']
-                cluster.multipart_preserving_data_port = data['multipart_preserving_data_port']
-                cluster.redirecting_preserving_data_port = data['redirecting_preserving_data_port']
+                cluster.redirecting_to_http_port = data["redirecting_to_http_port"]
+                cluster.preserving_data_port = data["preserving_data_port"]
+                cluster.multipart_preserving_data_port = data["multipart_preserving_data_port"]
+                cluster.redirecting_preserving_data_port = data["redirecting_preserving_data_port"]
             except:
                 logging.error(traceback.format_exc())
                 time.sleep(0.5)
             else:
                 break
         else:
-            assert False, 'Could not initialize mock server'
+            assert False, "Could not initialize mock server"
 
         yield cluster
 
@@ -65,92 +65,97 @@ def started_cluster():
 
 
 def run_query(instance, query, stdin=None):
-    logging.info('Running query "{}"...'.format(query))
+    logging.info("Running query '{}'...".format(query))
     result = instance.query(query, stdin=stdin)
-    logging.info('Query finished')
+    logging.info("Query finished")
     return result
 
 
-def test_get_with_redirect(started_cluster):
-    instance = started_cluster.instances['dummy']
-    format = 'column1 UInt32, column2 UInt32, column3 UInt32'
-    put_communication_data(started_cluster, '=== Get with redirect test ===')
+def test_get_with_redirect(started_cluster):
+    instance = started_cluster.instances["dummy"]
+    format = "column1 UInt32, column2 UInt32, column3 UInt32"
+
+    put_communication_data(started_cluster, "=== Get with redirect test ===")
     query = "select *, column1*column2*column3 from s3('http://{}:{}/', 'CSV', '{}')".format(started_cluster.mock_host, started_cluster.redirecting_to_http_port, format)
     stdout = run_query(instance, query)
     assert list(map(str.split, stdout.splitlines())) == [
-        ['42', '87', '44', '160776'],
-        ['55', '33', '81', '147015'],
-        ['1', '0', '9', '0'],
+        ["42", "87", "44", "160776"],
+        ["55", "33", "81", "147015"],
+        ["1", "0", "9", "0"],
     ]
 
-def test_put(started_cluster):
-    instance = started_cluster.instances['dummy']
-    format = 'column1 UInt32, column2 UInt32, column3 UInt32'
-    logging.info('Phase 3')
-    put_communication_data(started_cluster, '=== Put test ===')
-    values = '(1, 2, 3), (3, 2, 1), (78, 43, 45)'
+def test_put(started_cluster):
+    instance = started_cluster.instances["dummy"]
+    format = "column1 UInt32, column2 UInt32, column3 UInt32"
+
+    logging.info("Phase 3")
+    put_communication_data(started_cluster, "=== Put test ===")
+    values = "(1, 2, 3), (3, 2, 1), (78, 43, 45)"
     put_query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') values {}".format(started_cluster.mock_host, started_cluster.preserving_data_port, started_cluster.bucket, format, values)
     run_query(instance, put_query)
     data = get_communication_data(started_cluster)
-    received_data_completed = data['received_data_completed']
-    received_data = data['received_data']
-    finalize_data = data['finalize_data']
-    finalize_data_query = data['finalize_data_query']
-    assert received_data[-1].decode() == '1,2,3\n3,2,1\n78,43,45\n'
+    received_data_completed = data["received_data_completed"]
+    received_data = data["received_data"]
+    finalize_data = data["finalize_data"]
+    finalize_data_query = data["finalize_data_query"]
+    assert received_data[-1].decode() == "1,2,3\n3,2,1\n78,43,45\n"
     assert received_data_completed
-    assert finalize_data == '<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>'
-    assert finalize_data_query == 'uploadId=TEST'
+    assert finalize_data == "<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>"
+    assert finalize_data_query == "uploadId=TEST"
+
 
 def test_put_csv(started_cluster):
-    instance = started_cluster.instances['dummy']
-    format = 'column1 UInt32, column2 UInt32, column3 UInt32'
+    instance = started_cluster.instances["dummy"]
+    format = "column1 UInt32, column2 UInt32, column3 UInt32"
 
-    put_communication_data(started_cluster, '=== Put test CSV ===')
+    put_communication_data(started_cluster, "=== Put test CSV ===")
     put_query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') format CSV".format(started_cluster.mock_host, started_cluster.preserving_data_port, started_cluster.bucket, format)
-    csv_data = '8,9,16\n11,18,13\n22,14,2\n'
+    csv_data = "8,9,16\n11,18,13\n22,14,2\n"
    run_query(instance, put_query, stdin=csv_data)
     data = get_communication_data(started_cluster)
-    received_data_completed = data['received_data_completed']
-    received_data = data['received_data']
-    finalize_data = data['finalize_data']
-    finalize_data_query = data['finalize_data_query']
+    received_data_completed = data["received_data_completed"]
+    received_data = data["received_data"]
+    finalize_data = data["finalize_data"]
+    finalize_data_query = data["finalize_data_query"]
     assert received_data[-1].decode() == csv_data
     assert received_data_completed
-    assert finalize_data == '<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>'
-    assert finalize_data_query == 'uploadId=TEST'
+    assert finalize_data == "<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>hello-etag</ETag></Part></CompleteMultipartUpload>"
+    assert finalize_data_query == "uploadId=TEST"
+
 
 def test_put_with_redirect(started_cluster):
-    instance = started_cluster.instances['dummy']
-    format = 'column1 UInt32, column2 UInt32, column3 UInt32'
+    instance = started_cluster.instances["dummy"]
+    format = "column1 UInt32, column2 UInt32, column3 UInt32"
 
-    put_communication_data(started_cluster, '=== Put with redirect test ===')
-    other_values = '(1, 1, 1), (1, 1, 1), (11, 11, 11)'
+    put_communication_data(started_cluster, "=== Put with redirect test ===")
+    other_values = "(1, 1, 1), (1, 1, 1), (11, 11, 11)"
     query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') values {}".format(started_cluster.mock_host, started_cluster.redirecting_preserving_data_port, started_cluster.bucket, format, other_values)
     run_query(instance, query)
 
     query = "select *, column1*column2*column3 from s3('http://{}:{}/{}/test.csv', 'CSV', '{}')".format(started_cluster.mock_host, started_cluster.preserving_data_port, started_cluster.bucket, format)
     stdout = run_query(instance, query)
     assert list(map(str.split, stdout.splitlines())) == [
-        ['1', '1', '1', '1'],
-        ['1', '1', '1', '1'],
-        ['11', '11', '11', '1331'],
+        ["1", "1", "1", "1"],
+        ["1", "1", "1", "1"],
+        ["11", "11", "11", "1331"],
     ]
     data = get_communication_data(started_cluster)
-    received_data = data['received_data']
-    assert received_data[-1].decode() == '1,1,1\n1,1,1\n11,11,11\n'
+    received_data = data["received_data"]
+    assert received_data[-1].decode() == "1,1,1\n1,1,1\n11,11,11\n"
+
 
 def test_multipart_put(started_cluster):
-    instance = started_cluster.instances['dummy']
-    format = 'column1 UInt32, column2 UInt32, column3 UInt32'
+    instance = started_cluster.instances["dummy"]
+    format = "column1 UInt32, column2 UInt32, column3 UInt32"
 
-    put_communication_data(started_cluster, '=== Multipart test ===')
+    put_communication_data(started_cluster, "=== Multipart test ===")
     long_data = [[i, i+1, i+2] for i in range(100000)]
-    long_values = ''.join([ '{},{},{}\n'.format(x,y,z) for x, y, z in long_data ])
+    long_values = "".join([ "{},{},{}\n".format(x,y,z) for x, y, z in long_data ])
     put_query = "insert into table function s3('http://{}:{}/{}/test.csv', 'CSV', '{}') format CSV".format(started_cluster.mock_host, started_cluster.multipart_preserving_data_port, started_cluster.bucket, format)
     run_query(instance, put_query, stdin=long_values)
     data = get_communication_data(started_cluster)
-    assert 'multipart_received_data' in data
-    received_data = data['multipart_received_data']
-    assert received_data[-1].decode() == ''.join([ '{},{},{}\n'.format(x, y, z) for x, y, z in long_data ])
-    assert 1 < data['multipart_parts'] < 10000
+    assert "multipart_received_data" in data
+    received_data = data["multipart_received_data"]
+    assert received_data[-1].decode() == "".join([ "{},{},{}\n".format(x, y, z) for x, y, z in long_data ])
+    assert 1 < data["multipart_parts"] < 10000
diff --git a/dbms/tests/integration/test_storage_s3/test_server.py b/dbms/tests/integration/test_storage_s3/test_server.py
index 3c10445566a..09dfa1ca958 100644
--- a/dbms/tests/integration/test_storage_s3/test_server.py
+++ b/dbms/tests/integration/test_storage_s3/test_server.py
@@ -25,8 +25,8 @@ import xml.etree.ElementTree
 
 logging.getLogger().setLevel(logging.INFO)
-file_handler = logging.FileHandler('/var/log/clickhouse-server/test-server.log', 'a', encoding='utf-8')
-file_handler.setFormatter(logging.Formatter('%(asctime)s %(message)s'))
+file_handler = logging.FileHandler("/var/log/clickhouse-server/test-server.log", "a", encoding="utf-8")
+file_handler.setFormatter(logging.Formatter("%(asctime)s %(message)s"))
 logging.getLogger().addHandler(file_handler)
 logging.getLogger().addHandler(logging.StreamHandler())
 
@@ -54,21 +54,21 @@ def GetFreeTCPPortsAndIP(n):
 ), localhost = GetFreeTCPPortsAndIP(5)
 data = {
-    'redirecting_to_http_port': redirecting_to_http_port,
-    'preserving_data_port': preserving_data_port,
-    'multipart_preserving_data_port': multipart_preserving_data_port,
-    'redirecting_preserving_data_port': redirecting_preserving_data_port,
+    "redirecting_to_http_port": redirecting_to_http_port,
+    "preserving_data_port": preserving_data_port,
+    "multipart_preserving_data_port": multipart_preserving_data_port,
+    "redirecting_preserving_data_port": redirecting_preserving_data_port,
 }
 
 
 class SimpleHTTPServerHandler(BaseHTTPRequestHandler):
     def do_GET(self):
-        logging.info('GET {}'.format(self.path))
-        if self.path == '/milovidov/test.csv':
+        logging.info("GET {}".format(self.path))
+        if self.path == "/milovidov/test.csv":
             self.send_response(200)
-            self.send_header('Content-type', 'text/plain')
+            self.send_header("Content-type", "text/plain")
             self.end_headers()
-            self.wfile.write('42,87,44\n55,33,81\n1,0,9\n')
+            self.wfile.write("42,87,44\n55,33,81\n1,0,9\n")
         else:
             self.send_response(404)
             self.end_headers()
@@ -78,27 +78,27 @@ class SimpleHTTPServerHandler(BaseHTTPRequestHandler):
 class RedirectingToHTTPHandler(BaseHTTPRequestHandler):
     def do_GET(self):
         self.send_response(307)
-        self.send_header('Content-type', 'text/xml')
-        self.send_header('Location', 'http://{}:{}/milovidov/test.csv'.format(localhost, simple_server_port))
+        self.send_header("Content-type", "text/xml")
+        self.send_header("Location", "http://{}:{}/milovidov/test.csv".format(localhost, simple_server_port))
         self.end_headers()
-        self.wfile.write(r'''<?xml version="1.0" encoding="UTF-8"?>
+        self.wfile.write(r"""<?xml version="1.0" encoding="UTF-8"?>
 <Error>
   <Code>TemporaryRedirect</Code>
   <Message>Please re-send this request to the specified temporary endpoint.
   Continue to use the original request endpoint for future requests.</Message>
   <Endpoint>storage.yandexcloud.net</Endpoint>
-</Error>'''.encode())
+</Error>""".encode())
         self.finish()
 
 
 class PreservingDataHandler(BaseHTTPRequestHandler):
-    protocol_version = 'HTTP/1.1'
+    protocol_version = "HTTP/1.1"
 
     def parse_request(self):
         result = BaseHTTPRequestHandler.parse_request(self)
         # Adaptation to Python 3.
         if sys.version_info.major == 2 and result == True:
-            expect = self.headers.get('Expect', "")
+            expect = self.headers.get("Expect", "")
             if (expect.lower() == "100-continue" and self.protocol_version >= "HTTP/1.1" and self.request_version >= "HTTP/1.1"):
                 if not self.handle_expect_100():
                     return False
@@ -109,12 +109,12 @@ class PreservingDataHandler(BaseHTTPRequestHandler):
         if code in self.responses:
             message = self.responses[code][0]
         else:
-            message = ''
-        if self.request_version != 'HTTP/0.9':
+            message = ""
+        if self.request_version != "HTTP/0.9":
             self.wfile.write("%s %d %s\r\n" % (self.protocol_version, code, message))
 
     def handle_expect_100(self):
-        logging.info('Received Expect-100')
+        logging.info("Received Expect-100")
         self.send_response_only(100)
         self.end_headers()
         return True
@@ -122,37 +122,37 @@ class PreservingDataHandler(BaseHTTPRequestHandler):
     def do_POST(self):
         self.send_response(200)
         query = urlparse.urlparse(self.path).query
-        logging.info('PreservingDataHandler POST ?' + query)
-        if query == 'uploads':
-            post_data = r'''<?xml version="1.0" encoding="UTF-8"?>
-<hi><UploadId>TEST</UploadId></hi>'''.encode()
-            self.send_header('Content-length', str(len(post_data)))
-            self.send_header('Content-type', 'text/plain')
+        logging.info("PreservingDataHandler POST ?" + query)
+        if query == "uploads":
+            post_data = r"""<?xml version="1.0" encoding="UTF-8"?>
+<hi><UploadId>TEST</UploadId></hi>""".encode()
+            self.send_header("Content-length", str(len(post_data)))
+            self.send_header("Content-type", "text/plain")
             self.end_headers()
             self.wfile.write(post_data)
         else:
-            post_data = self.rfile.read(int(self.headers.get('Content-Length')))
-            self.send_header('Content-type', 'text/plain')
+            post_data = self.rfile.read(int(self.headers.get("Content-Length")))
+            self.send_header("Content-type", "text/plain")
             self.end_headers()
-            data['received_data_completed'] = True
-            data['finalize_data'] = post_data
-            data['finalize_data_query'] = query
+            data["received_data_completed"] = True
+            data["finalize_data"] = post_data
+            data["finalize_data_query"] = query
         self.finish()
 
     def do_PUT(self):
         self.send_response(200)
-        self.send_header('Content-type', 'text/plain')
-        self.send_header('ETag', 'hello-etag')
+        self.send_header("Content-type", "text/plain")
+        self.send_header("ETag", "hello-etag")
         self.end_headers()
         query = urlparse.urlparse(self.path).query
         path = urlparse.urlparse(self.path).path
-        logging.info('Content-Length = ' + self.headers.get('Content-Length'))
-        logging.info('PUT ' + query)
-        assert self.headers.get('Content-Length')
-        assert self.headers['Expect'] == '100-continue'
+        logging.info("Content-Length = " + self.headers.get("Content-Length"))
+        logging.info("PUT " + query)
+        assert self.headers.get("Content-Length")
+        assert self.headers["Expect"] == "100-continue"
         put_data = self.rfile.read()
-        data.setdefault('received_data', []).append(put_data)
-        logging.info('PUT to {}'.format(path))
+        data.setdefault("received_data", []).append(put_data)
+        logging.info("PUT to {}".format(path))
         self.server.storage[path] = put_data
         self.finish()
 
@@ -160,8 +160,8 @@ class PreservingDataHandler(BaseHTTPRequestHandler):
         path = urlparse.urlparse(self.path).path
         if path in self.server.storage:
             self.send_response(200)
-            self.send_header('Content-type', 'text/plain')
-            self.send_header('Content-length', str(len(self.server.storage[path])))
+            self.send_header("Content-type", "text/plain")
+            self.send_header("Content-length", str(len(self.server.storage[path])))
             self.end_headers()
             self.wfile.write(self.server.storage[path])
         else:
@@ -171,13 +171,13 @@ class PreservingDataHandler(BaseHTTPRequestHandler):
 
 
 class MultipartPreservingDataHandler(BaseHTTPRequestHandler):
-    protocol_version = 'HTTP/1.1'
+    protocol_version = "HTTP/1.1"
 
     def parse_request(self):
         result = BaseHTTPRequestHandler.parse_request(self)
         # Adaptation to Python 3.
         if sys.version_info.major == 2 and result == True:
-            expect = self.headers.get('Expect', "")
+            expect = self.headers.get("Expect", "")
             if (expect.lower() == "100-continue" and self.protocol_version >= "HTTP/1.1" and self.request_version >= "HTTP/1.1"):
                 if not self.handle_expect_100():
                     return False
@@ -188,78 +188,78 @@ class MultipartPreservingDataHandler(BaseHTTPRequestHandler):
         if code in self.responses:
             message = self.responses[code][0]
         else:
-            message = ''
-        if self.request_version != 'HTTP/0.9':
+            message = ""
+        if self.request_version != "HTTP/0.9":
             self.wfile.write("%s %d %s\r\n" % (self.protocol_version, code, message))
 
     def handle_expect_100(self):
-        logging.info('Received Expect-100')
+        logging.info("Received Expect-100")
         self.send_response_only(100)
         self.end_headers()
         return True
 
     def do_POST(self):
         query = urlparse.urlparse(self.path).query
-        logging.info('MultipartPreservingDataHandler POST ?' + query)
-        if query == 'uploads':
+        logging.info("MultipartPreservingDataHandler POST ?" + query)
+        if query == "uploads":
             self.send_response(200)
-            post_data = r'''<?xml version="1.0" encoding="UTF-8"?>
-<hi><UploadId>TEST</UploadId></hi>'''.encode()
-            self.send_header('Content-length', str(len(post_data)))
-            self.send_header('Content-type', 'text/plain')
+            post_data = r"""<?xml version="1.0" encoding="UTF-8"?>
+<hi><UploadId>TEST</UploadId></hi>""".encode()
+            self.send_header("Content-length", str(len(post_data)))
+            self.send_header("Content-type", "text/plain")
             self.end_headers()
             self.wfile.write(post_data)
         else:
             try:
-                assert query == 'uploadId=TEST'
-                logging.info('Content-Length = ' + self.headers.get('Content-Length'))
-                post_data = self.rfile.read(int(self.headers.get('Content-Length')))
+                assert query == "uploadId=TEST"
+                logging.info("Content-Length = " + self.headers.get("Content-Length"))
+                post_data = self.rfile.read(int(self.headers.get("Content-Length")))
                 root = xml.etree.ElementTree.fromstring(post_data)
-                assert root.tag == 'CompleteMultipartUpload'
+                assert root.tag == "CompleteMultipartUpload"
                 assert len(root) > 1
-                content = ''
+                content = ""
                 for i, part in enumerate(root):
-                    assert part.tag == 'Part'
+                    assert part.tag == "Part"
                     assert len(part) == 2
-                    assert part[0].tag == 'PartNumber'
-                    assert part[1].tag == 'ETag'
+                    assert part[0].tag == "PartNumber"
+                    assert part[1].tag == "ETag"
                     assert int(part[0].text) == i + 1
-                    content += self.server.storage['@'+part[1].text]
-                data.setdefault('multipart_received_data', []).append(content)
-                data['multipart_parts'] = len(root)
+                    content += self.server.storage["@"+part[1].text]
+                data.setdefault("multipart_received_data", []).append(content)
+                data["multipart_parts"] = len(root)
                 self.send_response(200)
-                self.send_header('Content-type', 'text/plain')
+                self.send_header("Content-type", "text/plain")
                 self.end_headers()
-                logging.info('Sending 200')
+                logging.info("Sending 200")
             except:
-                logging.error('Sending 500')
+                logging.error("Sending 500")
                 self.send_response(500)
         self.finish()
 
     def do_PUT(self):
         uid = uuid.uuid4()
         self.send_response(200)
-        self.send_header('Content-type', 'text/plain')
-        self.send_header('ETag', str(uid))
+        self.send_header("Content-type", "text/plain")
+        self.send_header("ETag", str(uid))
         self.end_headers()
         query = urlparse.urlparse(self.path).query
         path = urlparse.urlparse(self.path).path
-        logging.info('Content-Length = ' + self.headers.get('Content-Length'))
-        logging.info('PUT ' + query)
-        assert self.headers.get('Content-Length')
-        assert self.headers['Expect'] == '100-continue'
+        logging.info("Content-Length = " + self.headers.get("Content-Length"))
+        logging.info("PUT " + query)
+        assert self.headers.get("Content-Length")
+        assert self.headers["Expect"] == "100-continue"
         put_data = self.rfile.read()
-        data.setdefault('received_data', []).append(put_data)
-        logging.info('PUT to {}'.format(path))
-        self.server.storage['@'+str(uid)] = put_data
+        data.setdefault("received_data", []).append(put_data)
+        logging.info("PUT to {}".format(path))
+        self.server.storage["@"+str(uid)] = put_data
         self.finish()
 
     def do_GET(self):
         path = urlparse.urlparse(self.path).path
         if path in self.server.storage:
             self.send_response(200)
-            self.send_header('Content-type', 'text/plain')
-            self.send_header('Content-length', str(len(self.server.storage[path])))
+            self.send_header("Content-type", "text/plain")
+            self.send_header("Content-length", str(len(self.server.storage[path])))
             self.end_headers()
             self.wfile.write(self.server.storage[path])
         else:
@@ -269,13 +269,13 @@ class MultipartPreservingDataHandler(BaseHTTPRequestHandler):
 
 
 class RedirectingPreservingDataHandler(BaseHTTPRequestHandler):
-    protocol_version = 'HTTP/1.1'
+    protocol_version = "HTTP/1.1"
 
     def parse_request(self):
         result = BaseHTTPRequestHandler.parse_request(self)
         # Adaptation to Python 3.
         if sys.version_info.major == 2 and result == True:
-            expect = self.headers.get('Expect', "")
+            expect = self.headers.get("Expect", "")
             if (expect.lower() == "100-continue" and self.protocol_version >= "HTTP/1.1" and self.request_version >= "HTTP/1.1"):
                 if not self.handle_expect_100():
                     return False
@@ -286,46 +286,46 @@ class RedirectingPreservingDataHandler(BaseHTTPRequestHandler):
         if code in self.responses:
             message = self.responses[code][0]
         else:
-            message = ''
-        if self.request_version != 'HTTP/0.9':
+            message = ""
+        if self.request_version != "HTTP/0.9":
             self.wfile.write("%s %d %s\r\n" % (self.protocol_version, code, message))
 
     def handle_expect_100(self):
-        logging.info('Received Expect-100')
+        logging.info("Received Expect-100")
         return True
 
     def do_POST(self):
         query = urlparse.urlparse(self.path).query
         if query:
-            query = '?{}'.format(query)
+            query = "?{}".format(query)
         self.send_response(307)
-        self.send_header('Content-type', 'text/xml')
-        self.send_header('Location', 'http://{host}:{port}/{bucket}/test.csv{query}'.format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
+        self.send_header("Content-type", "text/xml")
+        self.send_header("Location", "http://{host}:{port}/{bucket}/test.csv{query}".format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
         self.end_headers()
-        self.wfile.write(r'''<?xml version="1.0" encoding="UTF-8"?>
+        self.wfile.write(r"""<?xml version="1.0" encoding="UTF-8"?>
 <Error>
   <Code>TemporaryRedirect</Code>
   <Message>Please re-send this request to the specified temporary endpoint.
   Continue to use the original request endpoint for future requests.</Message>
   <Endpoint>{host}:{port}</Endpoint>
-</Error>'''.format(host=localhost, port=preserving_data_port).encode())
+</Error>""".format(host=localhost, port=preserving_data_port).encode())
         self.finish()
 
     def do_PUT(self):
         query = urlparse.urlparse(self.path).query
         if query:
-            query = '?{}'.format(query)
+            query = "?{}".format(query)
         self.send_response(307)
-        self.send_header('Content-type', 'text/xml')
-        self.send_header('Location', 'http://{host}:{port}/{bucket}/test.csv{query}'.format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
+        self.send_header("Content-type", "text/xml")
+        self.send_header("Location", "http://{host}:{port}/{bucket}/test.csv{query}".format(host=localhost, port=preserving_data_port, bucket=bucket, query=query))
         self.end_headers()
-        self.wfile.write(r'''<?xml version="1.0" encoding="UTF-8"?>
+        self.wfile.write(r"""<?xml version="1.0" encoding="UTF-8"?>
 <Error>
   <Code>TemporaryRedirect</Code>
   <Message>Please re-send this request to the specified temporary endpoint.
   Continue to use the original request endpoint for future requests.</Message>
   <Endpoint>{host}:{port}</Endpoint>
-</Error>'''.format(host=localhost, port=preserving_data_port).encode())
+</Error>""".format(host=localhost, port=preserving_data_port).encode())
         self.finish()
@@ -357,8 +357,8 @@
 jobs = [ threading.Thread(target=server.serve_forever) for server in servers ]
 
 time.sleep(60) # Timeout
-logging.info('Shutting down')
+logging.info("Shutting down")
 [ server.shutdown() for server in servers ]
-logging.info('Joining threads')
+logging.info("Joining threads")
 [ job.join() for job in jobs ]
-logging.info('Done')
+logging.info("Done")