add test file for url basic auth

This commit is contained in:
michael1589 2021-11-27 23:54:33 +08:00
parent dbfe637f7b
commit 26a117c4c1
3 changed files with 238 additions and 0 deletions

View File

@ -0,0 +1,226 @@
#!/usr/bin/env python3
import socket
import csv
import sys
import tempfile
import threading
import os
import traceback
import urllib.request
import subprocess
from io import StringIO
from http.server import BaseHTTPRequestHandler, HTTPServer
def is_ipv6(host):
try:
socket.inet_aton(host)
return False
except:
return True
def get_local_port(host, ipv6):
if ipv6:
family = socket.AF_INET6
else:
family = socket.AF_INET
with socket.socket(family) as fd:
fd.bind((host, 0))
return fd.getsockname()[1]
CLICKHOUSE_HOST = os.environ.get('CLICKHOUSE_HOST', '127.0.0.1')
CLICKHOUSE_PORT_HTTP = os.environ.get('CLICKHOUSE_PORT_HTTP', '8123')
#####################################################################################
# This test starts an HTTP server and serves data to clickhouse url-engine based table.
# In order for it to work ip+port of http server (given below) should be
# accessible from clickhouse server.
#####################################################################################
# IP-address of this host accessible from the outside world. Get the first one
HTTP_SERVER_HOST = subprocess.check_output(['hostname', '-i']).decode('utf-8').strip().split()[0]
IS_IPV6 = is_ipv6(HTTP_SERVER_HOST)
HTTP_SERVER_PORT = get_local_port(HTTP_SERVER_HOST, IS_IPV6)
# IP address and port of the HTTP server started from this script.
HTTP_SERVER_ADDRESS = (HTTP_SERVER_HOST, HTTP_SERVER_PORT)
if IS_IPV6:
HTTP_SERVER_URL_STR = 'http://' + f'[{str(HTTP_SERVER_ADDRESS[0])}]:{str(HTTP_SERVER_ADDRESS[1])}' + "/"
else:
HTTP_SERVER_URL_STR = 'http://' + f'{str(HTTP_SERVER_ADDRESS[0])}:{str(HTTP_SERVER_ADDRESS[1])}' + "/"
CSV_DATA = os.path.join(tempfile._get_default_tempdir(), next(tempfile._get_candidate_names()))
def get_ch_answer(query):
host = CLICKHOUSE_HOST
if IS_IPV6:
host = f'[{host}]'
url = os.environ.get('CLICKHOUSE_URL', 'http://{host}:{port}'.format(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT_HTTP))
return urllib.request.urlopen(url, data=query.encode()).read().decode()
def check_answers(query, answer):
ch_answer = get_ch_answer(query)
if ch_answer.strip() != answer.strip():
print("FAIL on query:", query, file=sys.stderr)
print("Expected answer:", answer, file=sys.stderr)
print("Fetched answer :", ch_answer, file=sys.stderr)
raise Exception("Fail on query")
class CSVHTTPServer(BaseHTTPRequestHandler):
def _set_headers(self):
self.send_response(200)
self.send_header('Content-type', 'text/csv')
self.end_headers()
def do_GET(self):
self._set_headers()
with open(CSV_DATA, 'r') as fl:
reader = csv.reader(fl, delimiter=',')
for row in reader:
self.wfile.write((', '.join(row) + '\n').encode())
return
def read_chunk(self):
msg = ''
while True:
sym = self.rfile.read(1)
if sym == '':
break
msg += sym.decode('utf-8')
if msg.endswith('\r\n'):
break
length = int(msg[:-2], 16)
if length == 0:
return ''
content = self.rfile.read(length)
self.rfile.read(2) # read sep \r\n
return content.decode('utf-8')
def do_POST(self):
data = ''
while True:
chunk = self.read_chunk()
if not chunk:
break
data += chunk
with StringIO(data) as fl:
reader = csv.reader(fl, delimiter=',')
with open(CSV_DATA, 'a') as d:
for row in reader:
d.write(','.join(row) + '\n')
self._set_headers()
self.wfile.write(b"ok")
def log_message(self, format, *args):
return
class HTTPServerV6(HTTPServer):
address_family = socket.AF_INET6
def start_server(requests_amount):
if IS_IPV6:
httpd = HTTPServerV6(HTTP_SERVER_ADDRESS, CSVHTTPServer)
else:
httpd = HTTPServer(HTTP_SERVER_ADDRESS, CSVHTTPServer)
def real_func():
for i in range(requests_amount):
httpd.handle_request()
t = threading.Thread(target=real_func)
return t
# test section
def test_select(table_name="", schema="str String,numuint UInt32,numint Int32,double Float64", requests=[], answers=[], test_data=""):
with open(CSV_DATA, 'w') as f: # clear file
f.write('')
if test_data:
with open(CSV_DATA, 'w') as f:
f.write(test_data + "\n")
if table_name:
get_ch_answer("drop table if exists {}".format(table_name))
get_ch_answer("create table {} ({}) engine=URL('{}', 'CSV')".format(table_name, schema, HTTP_SERVER_URL_STR))
for i in range(len(requests)):
tbl = table_name
if not tbl:
tbl = "url('{addr}', 'CSV', '{schema}')".format(addr=HTTP_SERVER_URL_STR, schema=schema)
check_answers(requests[i].format(tbl=tbl), answers[i])
if table_name:
get_ch_answer("drop table if exists {}".format(table_name))
def test_insert(table_name="", schema="str String,numuint UInt32,numint Int32,double Float64", requests_insert=[], requests_select=[], answers=[]):
with open(CSV_DATA, 'w') as f: # flush test file
f.write('')
if table_name:
get_ch_answer("drop table if exists {}".format(table_name))
get_ch_answer("create table {} ({}) engine=URL('{}', 'CSV')".format(table_name, schema, HTTP_SERVER_URL_STR))
for req in requests_insert:
tbl = table_name
if not tbl:
tbl = "table function url('{addr}', 'CSV', '{schema}')".format(addr=HTTP_SERVER_URL_STR, schema=schema)
get_ch_answer(req.format(tbl=tbl))
for i in range(len(requests_select)):
tbl = table_name
if not tbl:
tbl = "url('{addr}', 'CSV', '{schema}')".format(addr=HTTP_SERVER_URL_STR, schema=schema)
check_answers(requests_select[i].format(tbl=tbl), answers[i])
if table_name:
get_ch_answer("drop table if exists {}".format(table_name))
def test_select_url_engine(requests=[], answers=[], test_data=""):
for i in range(len(requests)):
check_answers(requests[i], answers[i])
def main():
test_data = "Hello,2,-2,7.7\nWorld,2,-5,8.8"
select_only_requests = {
"select str,numuint,numint,double from {tbl}" : test_data.replace(',', '\t'),
"select numuint, count(*) from {tbl} group by numuint" : "2\t2",
"select str,numuint,numint,double from {tbl} limit 1": test_data.split("\n")[0].replace(',', '\t'),
}
insert_requests = [
"insert into {tbl} values('Hello',10,-2,7.7)('World',10,-5,7.7)",
"insert into {tbl} select 'Buy', number, 9-number, 9.9 from system.numbers limit 10",
]
select_requests = {
"select distinct numuint from {tbl} order by numuint": '\n'.join([str(i) for i in range(11)]),
"select count(*) from {tbl}": '12',
'select double, count(*) from {tbl} group by double': "7.7\t2\n9.9\t10"
}
select_requests_url_auth = {
"select * from url('http://guest:guest@127.0.0.1:4443/', 'RawBLOB', 'a String')": test_data.replace(',', '\t'),
}
t = start_server(len(select_only_requests) * 2 + (len(insert_requests) + len(select_requests)) * 2)
t.start()
test_select(requests=list(select_requests_url_auth.keys()), answers=list(select_requests_url_auth.values()), test_data=test_data)
t.join()
print("PASSED")
if __name__ == "__main__":
try:
main()
except Exception as ex:
exc_type, exc_value, exc_traceback = sys.exc_info()
traceback.print_tb(exc_traceback, file=sys.stderr)
print(ex, file=sys.stderr)
sys.stderr.flush()
os._exit(1)

View File

@ -0,0 +1 @@
PASSED

View File

@ -0,0 +1,11 @@
#!/usr/bin/env bash
# Tags: no-fasttest
# Tag no-fasttest: Not sure why fail even in sequential mode. Disabled for now to make some progress.
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
# We should have correct env vars from shell_config.sh to run this test
python3 "$CURDIR"/02126_url_auth.python