diff --git a/src/IO/WriteBufferFromHTTP.cpp b/src/IO/WriteBufferFromHTTP.cpp index f7456ad6b6c..355c42a23c9 100644 --- a/src/IO/WriteBufferFromHTTP.cpp +++ b/src/IO/WriteBufferFromHTTP.cpp @@ -11,6 +11,7 @@ WriteBufferFromHTTP::WriteBufferFromHTTP( const std::string & method, const std::string & content_type, const std::string & content_encoding, + const HTTPHeaderEntries & additional_headers, const ConnectionTimeouts & timeouts, size_t buffer_size_) : WriteBufferFromOStream(buffer_size_) @@ -28,6 +29,9 @@ WriteBufferFromHTTP::WriteBufferFromHTTP( if (!content_encoding.empty()) request.set("Content-Encoding", content_encoding); + for (const auto & header: additional_headers) + request.add(header.name, header.value); + LOG_TRACE((&Poco::Logger::get("WriteBufferToHTTP")), "Sending request to {}", uri.toString()); ostr = &session->sendRequest(request); diff --git a/src/IO/WriteBufferFromHTTP.h b/src/IO/WriteBufferFromHTTP.h index 6966bc8a5c5..ce5020dfa78 100644 --- a/src/IO/WriteBufferFromHTTP.h +++ b/src/IO/WriteBufferFromHTTP.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -22,6 +23,7 @@ public: const std::string & method = Poco::Net::HTTPRequest::HTTP_POST, // POST or PUT only const std::string & content_type = "", const std::string & content_encoding = "", + const HTTPHeaderEntries & additional_headers = {}, const ConnectionTimeouts & timeouts = {}, size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE); diff --git a/src/Storages/StorageURL.cpp b/src/Storages/StorageURL.cpp index 152dda8f360..c0ddb0bc48a 100644 --- a/src/Storages/StorageURL.cpp +++ b/src/Storages/StorageURL.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -458,6 +459,7 @@ StorageURLSink::StorageURLSink( ContextPtr context, const ConnectionTimeouts & timeouts, const CompressionMethod compression_method, + const HTTPHeaderEntries & headers, const String & http_method) : SinkToStorage(sample_block) { @@ -465,7 +467,7 @@ StorageURLSink::StorageURLSink( std::string content_encoding = toContentEncodingName(compression_method); write_buf = wrapWriteBufferWithCompressionMethod( - std::make_unique(Poco::URI(uri), http_method, content_type, content_encoding, timeouts), + std::make_unique(Poco::URI(uri), http_method, content_type, content_encoding, headers, timeouts), compression_method, 3); writer = FormatFactory::instance().getOutputFormat(format, *write_buf, sample_block, context, format_settings); @@ -530,6 +532,7 @@ public: ContextPtr context_, const ConnectionTimeouts & timeouts_, const CompressionMethod compression_method_, + const HTTPHeaderEntries & headers_, const String & http_method_) : PartitionedSink(partition_by, context_, sample_block_) , uri(uri_) @@ -539,6 +542,7 @@ public: , context(context_) , timeouts(timeouts_) , compression_method(compression_method_) + , headers(headers_) , http_method(http_method_) { } @@ -548,7 +552,7 @@ public: auto partition_path = PartitionedSink::replaceWildcards(uri, partition_id); context->getRemoteHostFilter().checkURL(Poco::URI(partition_path)); return std::make_shared( - partition_path, format, format_settings, sample_block, context, timeouts, compression_method, http_method); + partition_path, format, format_settings, sample_block, context, timeouts, compression_method, headers, http_method); } private: @@ -560,6 +564,7 @@ private: const ConnectionTimeouts timeouts; const CompressionMethod compression_method; + const HTTPHeaderEntries headers; const String http_method; }; @@ -821,6 +826,7 @@ SinkToStoragePtr IStorageURLBase::write(const ASTPtr & query, const StorageMetad context, getHTTPTimeouts(context), compression_method, + headers, http_method); } else @@ -833,6 +839,7 @@ SinkToStoragePtr IStorageURLBase::write(const ASTPtr & query, const StorageMetad context, getHTTPTimeouts(context), compression_method, + headers, http_method); } } diff --git a/src/Storages/StorageURL.h b/src/Storages/StorageURL.h index acf49f3cb71..1cfffc3e73a 100644 --- a/src/Storages/StorageURL.h +++ b/src/Storages/StorageURL.h @@ -137,6 +137,7 @@ public: ContextPtr context, const ConnectionTimeouts & timeouts, CompressionMethod compression_method, + const HTTPHeaderEntries & headers = {}, const String & method = Poco::Net::HTTPRequest::HTTP_POST); std::string getName() const override { return "StorageURLSink"; } diff --git a/tests/integration/test_storage_url_http_headers/__init__.py b/tests/integration/test_storage_url_http_headers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/integration/test_storage_url_http_headers/http_headers_echo_server.py b/tests/integration/test_storage_url_http_headers/http_headers_echo_server.py new file mode 100644 index 00000000000..b1a3f6777b1 --- /dev/null +++ b/tests/integration/test_storage_url_http_headers/http_headers_echo_server.py @@ -0,0 +1,31 @@ +import http.server + +RESULT_PATH = "/headers.txt" + + +class RequestHandler(http.server.BaseHTTPRequestHandler): + def log_message(self, *args): + with open(RESULT_PATH, "w") as f: + f.write(self.headers.as_string()) + + def do_POST(self): + self.rfile.read1() + self.send_response(200) + self.end_headers() + self.wfile.write(b'{"status":"ok"}') + + +if __name__ == "__main__": + with open(RESULT_PATH, "w") as f: + f.write("") + httpd = http.server.HTTPServer( + ( + "localhost", + 8000, + ), + RequestHandler, + ) + try: + httpd.serve_forever() + finally: + httpd.server_close() diff --git a/tests/integration/test_storage_url_http_headers/test.py b/tests/integration/test_storage_url_http_headers/test.py new file mode 100644 index 00000000000..3bbf5ec81c9 --- /dev/null +++ b/tests/integration/test_storage_url_http_headers/test.py @@ -0,0 +1,66 @@ +import pytest +import os +import time + +from . import http_headers_echo_server + +from helpers.cluster import ClickHouseCluster + +cluster = ClickHouseCluster(__file__) +server = cluster.add_instance("node") + + +def run_echo_server(): + script_dir = os.path.dirname(os.path.realpath(__file__)) + + server.copy_file_to_container( + os.path.join(script_dir, "http_headers_echo_server.py"), + "/http_headers_echo_server.py", + ) + + server.exec_in_container( + [ + "bash", + "-c", + "python3 /http_headers_echo_server.py > /http_headers_echo.server.log 2>&1", + ], + detach=True, + user="root", + ) + + for _ in range(0, 10): + ping_response = server.exec_in_container( + ["curl", "-s", f"http://localhost:8000/"], + nothrow=True, + ) + + if "html" in ping_response: + return + + print(ping_response) + + raise Exception("Echo server is not responding") + + +@pytest.fixture(scope="module") +def started_cluster(): + try: + cluster.start() + run_echo_server() + yield cluster + finally: + cluster.shutdown() + + +def test_storage_url_http_headers(started_cluster): + query = "INSERT INTO TABLE FUNCTION url('http://localhost:8000/', JSON, 'a UInt64', headers('X-My-Custom-Header'='test-header')) VALUES (1)" + + server.query(query) + + result = server.exec_in_container( + ["cat", http_headers_echo_server.RESULT_PATH], user="root" + ) + + print(result) + + assert "X-My-Custom-Header: test-header" in result