Fixed bug with S3 URLs containing + symbol, data with such keys could not be read previously.

This commit is contained in:
Vladimir Chebotarev 2021-05-04 09:25:33 +03:00
parent 08f10dced0
commit 524113f497
3 changed files with 81 additions and 1 deletions

View File

@ -176,7 +176,12 @@ void PocoHTTPClient::makeRequestInternal(
Poco::Net::HTTPRequest poco_request(Poco::Net::HTTPRequest::HTTP_1_1);
poco_request.setURI(target_uri.getPathAndQuery());
/// N.B. Aws::Http::URI will encode uri in appropriate way for AWS S3 server.
/// Poco::URI does that in incompatible way which can lead to double decoding.
/// For example, `+` symbol will not be converted to `%2B` by Poco and would
/// be received as space symbol.
Aws::Http::URI aws_target_uri(uri);
poco_request.setURI(aws_target_uri.GetPath() + aws_target_uri.GetQueryString());
switch (request.GetMethod())
{

View File

@ -0,0 +1,33 @@
import http.server
import sys
class RequestHandler(http.server.BaseHTTPRequestHandler):
def do_HEAD(self):
if self.path.startswith("/get-my-path/"):
self.send_response(200)
self.send_header("Content-Type", "text/plain")
self.end_headers()
elif self.path == "/":
self.send_response(200)
self.send_header("Content-Type", "text/plain")
self.end_headers()
else:
self.send_response(404)
self.send_header("Content-Type", "text/plain")
self.end_headers()
def do_GET(self):
self.do_HEAD()
if self.path.startswith("/get-my-path/"):
self.wfile.write(b'/' + self.path.split('/', maxsplit=2)[2].encode())
elif self.path == "/":
self.wfile.write(b"OK")
httpd = http.server.HTTPServer(("0.0.0.0", int(sys.argv[1])), RequestHandler)
httpd.serve_forever()

View File

@ -146,6 +146,47 @@ def test_put(cluster, maybe_auth, positive, compression):
assert values_csv == get_s3_file_content(cluster, bucket, filename)
@pytest.mark.parametrize("special", [
"space",
"plus"
])
def test_get_file_with_special(cluster, special):
symbol = {"space": " ", "plus": "+"}[special]
urlsafe_symbol = {"space": "%20", "plus": "%2B"}[special]
auth = "'minio','minio123',"
bucket = cluster.minio_restricted_bucket
instance = cluster.instances["dummy"]
table_format = "column1 UInt32, column2 UInt32, column3 UInt32"
values = [[12549, 2463, 19893], [64021, 38652, 66703], [81611, 39650, 83516], [11079, 59507, 61546], [51764, 69952, 6876], [41165, 90293, 29095], [40167, 78432, 48309], [81629, 81327, 11855], [55852, 21643, 98507], [6738, 54643, 41155]]
values_csv = ('\n'.join((','.join(map(str, row)) for row in values)) + '\n').encode()
filename = f"get_file_with_{special}_{symbol}two.csv"
put_s3_file_content(cluster, bucket, filename, values_csv)
get_query = f"SELECT * FROM s3('http://{cluster.minio_host}:{cluster.minio_port}/{bucket}/get_file_with_{special}_{urlsafe_symbol}two.csv', {auth}'CSV', '{table_format}') FORMAT TSV"
assert [list(map(int, l.split())) for l in run_query(instance, get_query).splitlines()] == values
get_query = f"SELECT * FROM s3('http://{cluster.minio_host}:{cluster.minio_port}/{bucket}/get_file_with_{special}*.csv', {auth}'CSV', '{table_format}') FORMAT TSV"
assert [list(map(int, l.split())) for l in run_query(instance, get_query).splitlines()] == values
get_query = f"SELECT * FROM s3('http://{cluster.minio_host}:{cluster.minio_port}/{bucket}/get_file_with_{special}_{urlsafe_symbol}*.csv', {auth}'CSV', '{table_format}') FORMAT TSV"
assert [list(map(int, l.split())) for l in run_query(instance, get_query).splitlines()] == values
@pytest.mark.parametrize("special", [
"space",
"plus",
"plus2"
])
def test_get_path_with_special(cluster, special):
symbol = {"space": "%20", "plus": "%2B", "plus2": "%2B"}[special]
safe_symbol = {"space": "%20", "plus": "+", "plus2": "%2B"}[special]
auth = "'minio','minio123',"
table_format = "column1 String"
instance = cluster.instances["dummy"]
get_query = f"SELECT * FROM s3('http://resolver:8082/get-my-path/{safe_symbol}.csv', {auth}'CSV', '{table_format}') FORMAT TSV"
assert run_query(instance, get_query).splitlines() == [f"/{symbol}.csv"]
# Test put no data to S3.
@pytest.mark.parametrize("auth", [
"'minio','minio123',"
@ -389,6 +430,7 @@ def run_s3_mocks(cluster):
mocks = (
("mock_s3.py", "resolver", "8080"),
("unstable_server.py", "resolver", "8081"),
("echo.py", "resolver", "8082"),
)
for mock_filename, container, port in mocks:
container_id = cluster.get_container_id(container)