Correctly append arguments

Antonio Andelic 2023-05-04 07:56:00 +00:00
parent ee9fae6aa2
commit 8769ac2187
10 changed files with 141 additions and 17 deletions

View File

@@ -41,13 +41,19 @@
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
StorageS3Cluster::StorageS3Cluster(
const Configuration & configuration_,
const StorageID & table_id_,
const ColumnsDescription & columns_,
const ConstraintsDescription & constraints_,
ContextPtr context_,
bool structure_argument_was_provided_)
bool structure_argument_was_provided_,
bool format_argument_was_provided_)
: IStorageCluster(table_id_)
, log(&Poco::Logger::get("StorageS3Cluster (" + table_id_.table_name + ")"))
, s3_configuration{configuration_}
@@ -55,6 +61,7 @@ StorageS3Cluster::StorageS3Cluster(
, format_name(configuration_.format)
, compression_method(configuration_.compression_method)
, structure_argument_was_provided(structure_argument_was_provided_)
, format_argument_was_provided(format_argument_was_provided_)
{
context_->getGlobalContext()->getRemoteHostFilter().checkURL(configuration_.url.uri);
StorageInMemoryMetadata storage_metadata;
@@ -89,6 +96,28 @@ void StorageS3Cluster::updateConfigurationIfChanged(ContextPtr local_context)
s3_configuration.update(local_context);
}
namespace
{
void addColumnsStructureToQueryWithS3ClusterEngine(ASTPtr & query, const String & structure, bool format_argument_was_provided, const String & function_name)
{
ASTExpressionList * expression_list = extractTableFunctionArgumentsFromSelectQuery(query);
if (!expression_list)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected SELECT query from table function {}, got '{}'", function_name, queryToString(query));
auto structure_literal = std::make_shared<ASTLiteral>(structure);
if (!format_argument_was_provided)
{
auto format_literal = std::make_shared<ASTLiteral>("auto");
expression_list->children.push_back(format_literal);
}
expression_list->children.push_back(structure_literal);
}
}
/// The code executes on the initiator
Pipe StorageS3Cluster::read(
const Names & column_names,
@@ -127,8 +156,8 @@ Pipe StorageS3Cluster::read(
const bool add_agg_info = processed_stage == QueryProcessingStage::WithMergeableState;
if (!structure_argument_was_provided)
addColumnsStructureToQueryWithClusterEngine(
query_to_send, StorageDictionary::generateNamesAndTypesDescription(storage_snapshot->metadata->getColumns().getAll()), 5, getName());
addColumnsStructureToQueryWithS3ClusterEngine(
query_to_send, StorageDictionary::generateNamesAndTypesDescription(storage_snapshot->metadata->getColumns().getAll()), format_argument_was_provided, getName());
RestoreQualifiedNamesVisitor::Data data;
data.distributed_table = DatabaseAndTableWithAlias(*getTableExpression(query_info.query->as<ASTSelectQuery &>(), 0));
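
The crux of the fix is in the hunks above: when the original query omitted the format argument, the initiator must first append a literal 'auto' into the format slot before appending the structure literal, otherwise workers would parse the structure string as the format. A minimal, self-contained sketch of that positional-append rule follows; the std::vector model and all names are illustrative stand-ins, not ClickHouse's AST API.

#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for ASTExpressionList::children: positional argument literals.
using Args = std::vector<std::string>;

// Same append order as addColumnsStructureToQueryWithS3ClusterEngine: pad the
// format slot with 'auto' first if the user never supplied a format.
void appendStructure(Args & args, const std::string & structure, bool format_argument_was_provided)
{
    if (!format_argument_was_provided)
        args.push_back("'auto'"); // keeps the structure literal out of the format slot
    args.push_back("'" + structure + "'");
}

int main()
{
    Args args{"'http://host/bucket/key.csv'"}; // url only, no format given
    appendStructure(args, "c1 String", /*format_argument_was_provided=*/false);
    for (const auto & arg : args)
        std::cout << arg << ' ';
    std::cout << '\n'; // 'http://host/bucket/key.csv' 'auto' 'c1 String'
}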

View File

@@ -32,7 +32,8 @@ public:
const ColumnsDescription & columns_,
const ConstraintsDescription & constraints_,
ContextPtr context_,
bool structure_argument_was_provided_);
bool structure_argument_was_provided_,
bool format_argument_was_provided_);
std::string getName() const override { return "S3Cluster"; }
@@ -59,6 +60,7 @@ private:
NamesAndTypesList virtual_columns;
Block virtual_block;
bool structure_argument_was_provided;
bool format_argument_was_provided;
};

View File

@@ -14,7 +14,7 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
static ASTExpressionList * extractTableFunctionArgumentsFromSelectQuery(ASTPtr & query)
ASTExpressionList * extractTableFunctionArgumentsFromSelectQuery(ASTPtr & query)
{
auto * select_query = query->as<ASTSelectQuery>();
if (!select_query || !select_query->tables())

View File

@@ -1,10 +1,13 @@
#pragma once
#include <Parsers/IAST.h>
#include <Parsers/ASTExpressionList.h>
namespace DB
{
ASTExpressionList * extractTableFunctionArgumentsFromSelectQuery(ASTPtr & query);
/// Add structure argument for queries with s3Cluster/hdfsCluster table functions.
void addColumnsStructureToQueryWithClusterEngine(ASTPtr & query, const String & structure, size_t max_arguments, const String & function_name);

View File

@@ -31,9 +31,15 @@ namespace ErrorCodes
/// This is needed to avoid copy-paste: s3Cluster arguments differ only in one additional (first) argument - the cluster name
void TableFunctionS3::parseArgumentsImpl(
const String & error_message, ASTs & args, ContextPtr context, StorageS3::Configuration & s3_configuration, bool get_format_from_file)
TableFunctionS3::ArgumentParseResult TableFunctionS3::parseArgumentsImpl(
const String & error_message,
ASTs & args,
ContextPtr context,
StorageS3::Configuration & s3_configuration,
bool get_format_from_file)
{
ArgumentParseResult result;
if (auto named_collection = tryGetNamedCollectionWithOverrides(args, context))
{
StorageS3::processNamedCollectionResult(s3_configuration, *named_collection);
@@ -133,10 +139,16 @@ void TableFunctionS3::parseArgumentsImpl(
s3_configuration.url = S3::URI(checkAndGetLiteralArgument<String>(args[0], "url"));
if (args_to_idx.contains("format"))
{
s3_configuration.format = checkAndGetLiteralArgument<String>(args[args_to_idx["format"]], "format");
result.has_format_argument = true;
}
if (args_to_idx.contains("structure"))
{
s3_configuration.structure = checkAndGetLiteralArgument<String>(args[args_to_idx["structure"]], "structure");
result.has_structure_argument = true;
}
if (args_to_idx.contains("compression_method"))
s3_configuration.compression_method = checkAndGetLiteralArgument<String>(args[args_to_idx["compression_method"]], "compression_method");
@@ -155,6 +167,8 @@ void TableFunctionS3::parseArgumentsImpl(
/// For DataLake table functions, we should specify a default format.
if (s3_configuration.format == "auto" && get_format_from_file)
s3_configuration.format = FormatFactory::instance().getFormatFromFileName(s3_configuration.url.uri.getPath(), true);
return result;
}
void TableFunctionS3::parseArguments(const ASTPtr & ast_function, ContextPtr context)
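
With this change parseArgumentsImpl reports which optional arguments were explicitly present, instead of callers inferring it from sentinel values such as configuration.structure != "auto", a check that misfires when a user passes 'auto' explicitly. Below is a minimal sketch of the parse-result pattern under a simplified positional signature url [, format [, structure]]; the names are illustrative, not the real ClickHouse signatures.

#include <string>
#include <vector>

struct ArgumentParseResult
{
    bool has_format_argument = false;
    bool has_structure_argument = false;
};

// Records which optional arguments the caller actually supplied, so later
// stages do not have to guess from 'auto' sentinel values.
ArgumentParseResult parseArgs(const std::vector<std::string> & args,
                              std::string & format, std::string & structure)
{
    ArgumentParseResult result;
    if (args.size() > 1)
    {
        format = args[1];
        result.has_format_argument = true;
    }
    if (args.size() > 2)
    {
        structure = args[2];
        result.has_structure_argument = true;
    }
    return result;
}

int main()
{
    std::string format = "auto";
    std::string structure = "auto";
    auto res = parseArgs({"url", "auto"}, format, structure);
    // Format was given explicitly (even though its value is "auto");
    // structure was not.
    return (res.has_format_argument && !res.has_structure_argument) ? 0 : 1;
}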

View File

@@ -43,7 +43,14 @@ public:
{
return {"_path", "_file"};
}
static void parseArgumentsImpl(
struct ArgumentParseResult
{
bool has_format_argument = false;
bool has_structure_argument = false;
};
static ArgumentParseResult parseArgumentsImpl(
const String & error_message,
ASTs & args,
ContextPtr context,

View File

@@ -45,9 +45,6 @@ void TableFunctionS3Cluster::parseArguments(const ASTPtr & ast_function, Context
ASTs & args = args_func.at(0)->children;
for (auto & arg : args)
arg = evaluateConstantExpressionOrIdentifierAsLiteral(arg, context);
constexpr auto fmt_string = "The signature of table function {} could be the following:\n"
" - cluster, url\n"
" - cluster, url, format\n"
@@ -61,7 +58,10 @@ void TableFunctionS3Cluster::parseArguments(const ASTPtr & ast_function, Context
if (args.size() < 2 || args.size() > 7)
throw Exception::createDeprecated(message, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
/// These arguments are always the first
/// Evaluate only the first argument; everything else will be handled by TableFunctionS3
args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(args[0], context);
/// Cluster name is always the first
configuration.cluster_name = checkAndGetLiteralArgument<String>(args[0], "cluster_name");
if (!context->tryGetCluster(configuration.cluster_name))
@@ -69,11 +69,11 @@ void TableFunctionS3Cluster::parseArguments(const ASTPtr & ast_function, Context
/// Just cut the first arg (cluster_name) and try to parse s3 table function arguments as is
ASTs clipped_args;
clipped_args.reserve(args.size());
clipped_args.reserve(args.size() - 1);
std::copy(args.begin() + 1, args.end(), std::back_inserter(clipped_args));
/// StorageS3ClusterConfiguration inherits from StorageS3::Configuration, so it is safe to upcast it.
TableFunctionS3::parseArgumentsImpl(message.text, clipped_args, context, static_cast<StorageS3::Configuration &>(configuration));
argument_parse_result = TableFunctionS3::parseArgumentsImpl(message.text, clipped_args, context, static_cast<StorageS3::Configuration &>(configuration));
}
@@ -94,9 +94,8 @@ StoragePtr TableFunctionS3Cluster::executeImpl(
{
StoragePtr storage;
ColumnsDescription columns;
bool structure_argument_was_provided = configuration.structure != "auto";
if (structure_argument_was_provided)
if (argument_parse_result.has_structure_argument)
{
columns = parseColumnsListFromString(configuration.structure, context);
}
@@ -126,7 +125,8 @@ StoragePtr TableFunctionS3Cluster::executeImpl(
columns,
ConstraintsDescription{},
context,
structure_argument_was_provided);
argument_parse_result.has_structure_argument,
argument_parse_result.has_format_argument);
}
storage->startup();

View File

@@ -5,6 +5,7 @@
#if USE_AWS_S3
#include <TableFunctions/ITableFunction.h>
#include <TableFunctions/TableFunctionS3.h>
#include <Storages/StorageS3Cluster.h>
@@ -52,6 +53,7 @@ protected:
mutable StorageS3Cluster::Configuration configuration;
ColumnsDescription structure_hint;
TableFunctionS3::ArgumentParseResult argument_parse_result;
};
}

View File

@@ -0,0 +1,25 @@
import sys

from bottle import route, run, request, response

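# Echo the value of the MyCustomHeader request header back as the response
# body so the test can verify that custom headers reach the mock S3 server.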
@route("/<_bucket>/<_path:path>")
def server(_bucket, _path):
result = (
request.headers["MyCustomHeader"]
if "MyCustomHeader" in request.headers
else "unknown"
)
response.content_type = "text/plain"
response.set_header("Content-Length", len(result))
return result
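# Liveness endpoint; returns OK so the test harness can tell the mock is up.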
@route("/")
def ping():
response.content_type = "text/plain"
response.set_header("Content-Length", 2)
return "OK"
run(host="0.0.0.0", port=int(sys.argv[1]))

View File

@@ -8,6 +8,7 @@ import time
import pytest
from helpers.cluster import ClickHouseCluster
from helpers.test_tools import TSV
from helpers.mock_servers import start_mock_servers
logging.getLogger().setLevel(logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler())
@@ -49,6 +50,17 @@ def create_buckets_s3(cluster):
        print(obj.object_name)
def run_s3_mocks(started_cluster):
    script_dir = os.path.join(os.path.dirname(__file__), "s3_mocks")
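    # Each tuple is (script file, instance name, port); this layout follows
    # how helpers.mock_servers is used elsewhere in the ClickHouse test suite.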
    start_mock_servers(
        started_cluster,
        script_dir,
        [
            ("s3_mock.py", "resolver", "8080"),
        ],
    )
@pytest.fixture(scope="module")
def started_cluster():
    try:
@@ -79,6 +91,8 @@ def started_cluster():
        create_buckets_s3(cluster)

        run_s3_mocks(cluster)

        yield cluster
    finally:
        shutil.rmtree(os.path.join(SCRIPT_DIR, "data/generated/"))
@@ -364,3 +378,31 @@ def test_parallel_distributed_insert_select_with_schema_inference(started_cluste
    count = int(node.query("SELECT count() FROM parallel_insert_select"))
    assert count == actual_count
def test_cluster_with_header(started_cluster):
    node = started_cluster.instances["s0_0_0"]

    assert (
        node.query(
            "SELECT * from s3('http://resolver:8080/bucket/key.csv', headers(MyCustomHeader = 'SomeValue'))"
        )
        == "SomeValue\n"
    )
    assert (
        node.query(
            "SELECT * from s3('http://resolver:8080/bucket/key.csv', headers(MyCustomHeader = 'SomeValue'), 'CSV')"
        )
        == "SomeValue\n"
    )
    assert (
        node.query(
            "SELECT * from s3Cluster('cluster_simple', 'http://resolver:8080/bucket/key.csv', headers(MyCustomHeader = 'SomeValue'))"
        )
        == "SomeValue\n"
    )
    assert (
        node.query(
            "SELECT * from s3Cluster('cluster_simple', 'http://resolver:8080/bucket/key.csv', headers(MyCustomHeader = 'SomeValue'), 'CSV')"
        )
        == "SomeValue\n"
    )