diff --git a/src/Storages/StorageURL.cpp b/src/Storages/StorageURL.cpp index 527458ab668..382466c32d4 100644 --- a/src/Storages/StorageURL.cpp +++ b/src/Storages/StorageURL.cpp @@ -128,6 +128,17 @@ namespace try { + std::string user_info = request_uri.getUserInfo(); + if (!user_info.empty()) + { + std::size_t n = user_info.find(':'); + if (n != std::string::npos) + { + credentials.setUsername(user_info.substr(0, n)); + credentials.setPassword(user_info.substr(n+1)); + } + } + read_buf = wrapReadBufferWithCompressionMethod( std::make_unique( request_uri, diff --git a/tests/queries/0_stateless/02126_url_auth.python b/tests/queries/0_stateless/02126_url_auth.python new file mode 100644 index 00000000000..60009624c76 --- /dev/null +++ b/tests/queries/0_stateless/02126_url_auth.python @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 + +import socket +import csv +import sys +import tempfile +import threading +import os +import traceback +import urllib.request +import subprocess +from io import StringIO +from http.server import BaseHTTPRequestHandler, HTTPServer + +def is_ipv6(host): + try: + socket.inet_aton(host) + return False + except: + return True + +def get_local_port(host, ipv6): + if ipv6: + family = socket.AF_INET6 + else: + family = socket.AF_INET + + with socket.socket(family) as fd: + fd.bind((host, 0)) + return fd.getsockname()[1] + +CLICKHOUSE_HOST = os.environ.get('CLICKHOUSE_HOST', '127.0.0.1') +CLICKHOUSE_PORT_HTTP = os.environ.get('CLICKHOUSE_PORT_HTTP', '8123') + +##################################################################################### +# This test starts an HTTP server and serves data to clickhouse url-engine based table. +# In order for it to work ip+port of http server (given below) should be +# accessible from clickhouse server. +##################################################################################### + +# IP-address of this host accessible from the outside world. Get the first one +HTTP_SERVER_HOST = subprocess.check_output(['hostname', '-i']).decode('utf-8').strip().split()[0] +IS_IPV6 = is_ipv6(HTTP_SERVER_HOST) +HTTP_SERVER_PORT = get_local_port(HTTP_SERVER_HOST, IS_IPV6) + +# IP address and port of the HTTP server started from this script. +HTTP_SERVER_ADDRESS = (HTTP_SERVER_HOST, HTTP_SERVER_PORT) +if IS_IPV6: + HTTP_SERVER_URL_STR = 'http://' + f'[{str(HTTP_SERVER_ADDRESS[0])}]:{str(HTTP_SERVER_ADDRESS[1])}' + "/" +else: + HTTP_SERVER_URL_STR = 'http://' + f'{str(HTTP_SERVER_ADDRESS[0])}:{str(HTTP_SERVER_ADDRESS[1])}' + "/" + +CSV_DATA = os.path.join(tempfile._get_default_tempdir(), next(tempfile._get_candidate_names())) + +def get_ch_answer(query): + host = CLICKHOUSE_HOST + if IS_IPV6: + host = f'[{host}]' + + url = os.environ.get('CLICKHOUSE_URL', 'http://{host}:{port}'.format(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT_HTTP)) + return urllib.request.urlopen(url, data=query.encode()).read().decode() + +def check_answers(query, answer): + ch_answer = get_ch_answer(query) + if ch_answer.strip() != answer.strip(): + print("FAIL on query:", query, file=sys.stderr) + print("Expected answer:", answer, file=sys.stderr) + print("Fetched answer :", ch_answer, file=sys.stderr) + raise Exception("Fail on query") + +class CSVHTTPServer(BaseHTTPRequestHandler): + def _set_headers(self): + self.send_response(200) + self.send_header('Content-type', 'text/csv') + self.end_headers() + + def do_GET(self): + self._set_headers() + self.wfile.write(('hello, world').encode()) + # with open(CSV_DATA, 'r') as fl: + # reader = csv.reader(fl, delimiter=',') + # for row in reader: + # self.wfile.write((', '.join(row) + '\n').encode()) + return + + def read_chunk(self): + msg = '' + while True: + sym = self.rfile.read(1) + if sym == '': + break + msg += sym.decode('utf-8') + if msg.endswith('\r\n'): + break + length = int(msg[:-2], 16) + if length == 0: + return '' + content = self.rfile.read(length) + self.rfile.read(2) # read sep \r\n + return content.decode('utf-8') + + def do_POST(self): + data = '' + while True: + chunk = self.read_chunk() + if not chunk: + break + data += chunk + with StringIO(data) as fl: + reader = csv.reader(fl, delimiter=',') + with open(CSV_DATA, 'a') as d: + for row in reader: + d.write(','.join(row) + '\n') + self._set_headers() + self.wfile.write(b"ok") + + def log_message(self, format, *args): + return + + +class HTTPServerV6(HTTPServer): + address_family = socket.AF_INET6 + +def start_server(requests_amount): + if IS_IPV6: + httpd = HTTPServerV6(HTTP_SERVER_ADDRESS, CSVHTTPServer) + else: + httpd = HTTPServer(HTTP_SERVER_ADDRESS, CSVHTTPServer) + + def real_func(): + for i in range(requests_amount): + httpd.handle_request() + + t = threading.Thread(target=real_func) + return t + +# test section + +def test_select(table_name="", schema="str String,numuint UInt32,numint Int32,double Float64", requests=[], answers=[], test_data=""): + with open(CSV_DATA, 'w') as f: # clear file + f.write('') + + if test_data: + with open(CSV_DATA, 'w') as f: + f.write(test_data + "\n") + + if table_name: + get_ch_answer("drop table if exists {}".format(table_name)) + get_ch_answer("create table {} ({}) engine=URL('{}', 'CSV')".format(table_name, schema, HTTP_SERVER_URL_STR)) + + for i in range(len(requests)): + tbl = table_name + if not tbl: + tbl = "url('{addr}', 'CSV', '{schema}')".format(addr=HTTP_SERVER_URL_STR, schema=schema) + check_answers(requests[i].format(tbl=tbl), answers[i]) + + if table_name: + get_ch_answer("drop table if exists {}".format(table_name)) + +def test_insert(table_name="", schema="str String,numuint UInt32,numint Int32,double Float64", requests_insert=[], requests_select=[], answers=[]): + with open(CSV_DATA, 'w') as f: # flush test file + f.write('') + + if table_name: + get_ch_answer("drop table if exists {}".format(table_name)) + get_ch_answer("create table {} ({}) engine=URL('{}', 'CSV')".format(table_name, schema, HTTP_SERVER_URL_STR)) + + for req in requests_insert: + tbl = table_name + if not tbl: + tbl = "table function url('{addr}', 'CSV', '{schema}')".format(addr=HTTP_SERVER_URL_STR, schema=schema) + get_ch_answer(req.format(tbl=tbl)) + + + for i in range(len(requests_select)): + tbl = table_name + if not tbl: + tbl = "url('{addr}', 'CSV', '{schema}')".format(addr=HTTP_SERVER_URL_STR, schema=schema) + check_answers(requests_select[i].format(tbl=tbl), answers[i]) + + if table_name: + get_ch_answer("drop table if exists {}".format(table_name)) + +def test_select_url_engine(requests=[], answers=[], test_data=""): + for i in range(len(requests)): + check_answers(requests[i], answers[i]) + +def main(): + test_data = "Hello,2,-2,7.7\nWorld,2,-5,8.8" + """ + select_only_requests = { + "select str,numuint,numint,double from {tbl}" : test_data.replace(',', '\t'), + "select numuint, count(*) from {tbl} group by numuint" : "2\t2", + "select str,numuint,numint,double from {tbl} limit 1": test_data.split("\n")[0].replace(',', '\t'), + } + + insert_requests = [ + "insert into {tbl} values('Hello',10,-2,7.7)('World',10,-5,7.7)", + "insert into {tbl} select 'Buy', number, 9-number, 9.9 from system.numbers limit 10", + ] + + select_requests = { + "select distinct numuint from {tbl} order by numuint": '\n'.join([str(i) for i in range(11)]), + "select count(*) from {tbl}": '12', + 'select double, count(*) from {tbl} group by double': "7.7\t2\n9.9\t10" + } + """ + + if IS_IPV6: + query = "select * from url('http://guest:guest@" + f'[{str(HTTP_SERVER_ADDRESS[0])}]:{str(HTTP_SERVER_ADDRESS[1])}' + "/', 'RawBLOB', 'a String')" + else: + query = "select * from url('http://guest:guest@" + f'{str(HTTP_SERVER_ADDRESS[0])}:{str(HTTP_SERVER_ADDRESS[1])}' + "/', 'RawBLOB', 'a String')" + + + + select_requests_url_auth = { + query : 'hello, world', + } + + t = start_server(len(list(select_requests_url_auth.keys()))) + t.start() + test_select(requests=list(select_requests_url_auth.keys()), answers=list(select_requests_url_auth.values()), test_data=test_data) + t.join() + print("PASSED") + + +if __name__ == "__main__": + try: + main() + except Exception as ex: + exc_type, exc_value, exc_traceback = sys.exc_info() + traceback.print_tb(exc_traceback, file=sys.stderr) + print(ex, file=sys.stderr) + sys.stderr.flush() + + os._exit(1) diff --git a/tests/queries/0_stateless/02126_url_auth.reference b/tests/queries/0_stateless/02126_url_auth.reference new file mode 100644 index 00000000000..53cdf1e9393 --- /dev/null +++ b/tests/queries/0_stateless/02126_url_auth.reference @@ -0,0 +1 @@ +PASSED diff --git a/tests/queries/0_stateless/02126_url_auth.sh b/tests/queries/0_stateless/02126_url_auth.sh new file mode 100755 index 00000000000..0013c3c1c0c --- /dev/null +++ b/tests/queries/0_stateless/02126_url_auth.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +# Tags: no-fasttest +# Tag no-fasttest: Not sure why fail even in sequential mode. Disabled for now to make some progress. + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +# We should have correct env vars from shell_config.sh to run this test + +python3 "$CURDIR"/02126_url_auth.python