From 587fbdd10d93f45c3ffc5d555f240e18cd3bf12e Mon Sep 17 00:00:00 2001 From: Nikita Mikhaylov Date: Tue, 6 Apr 2021 22:18:45 +0300 Subject: [PATCH] better --- src/Processors/Sources/SourceWithProgress.h | 4 +- src/Server/TCPHandler.cpp | 75 ++------- src/Storages/StorageS3.cpp | 156 +++++++++++------- src/Storages/StorageS3.h | 12 +- src/Storages/StorageS3Distributed.cpp | 32 ++-- src/Storages/StorageS3Distributed.h | 12 +- src/Storages/TaskSupervisor.h | 41 ++--- .../TableFunctionS3Distributed.cpp | 82 ++++++--- .../TableFunctionS3Distributed.h | 2 +- 9 files changed, 208 insertions(+), 208 deletions(-) diff --git a/src/Processors/Sources/SourceWithProgress.h b/src/Processors/Sources/SourceWithProgress.h index 25ff3eacec7..3aa7a81f418 100644 --- a/src/Processors/Sources/SourceWithProgress.h +++ b/src/Processors/Sources/SourceWithProgress.h @@ -55,12 +55,12 @@ public: void setProgressCallback(const ProgressCallback & callback) final { progress_callback = callback; } void addTotalRowsApprox(size_t value) final { total_rows_approx += value; } - void work() override; - protected: /// Call this method to provide information about progress. void progress(const Progress & value); + void work() override; + private: StreamLocalLimits limits; SizeLimits leaf_limits; diff --git a/src/Server/TCPHandler.cpp b/src/Server/TCPHandler.cpp index 3341e1b9eb2..3b8823e1e86 100644 --- a/src/Server/TCPHandler.cpp +++ b/src/Server/TCPHandler.cpp @@ -471,11 +471,8 @@ bool TCPHandler::readDataNext(const size_t & poll_interval, const int & receive_ /// We are waiting for a packet from the client. Thus, every `POLL_INTERVAL` seconds check whether we need to shut down. while (true) { - { - std::lock_guard lock(buffer_mutex); - if (static_cast(*in).poll(poll_interval)) - break; - } + if (static_cast(*in).poll(poll_interval)) + break; /// Do we need to shut down? if (server.isCancelled()) @@ -494,15 +491,12 @@ bool TCPHandler::readDataNext(const size_t & poll_interval, const int & receive_ } } + /// If client disconnected. + if (in->eof()) { - std::lock_guard lock(buffer_mutex); - /// If client disconnected. - if (in->eof()) - { - LOG_INFO(log, "Client has dropped the connection, cancel the query."); - state.is_connection_closed = true; - return false; - } + LOG_INFO(log, "Client has dropped the connection, cancel the query."); + state.is_connection_closed = true; + return false; } /// We accept and process data. And if they are over, then we leave. @@ -670,6 +664,8 @@ void TCPHandler::processOrdinaryQueryWithProcessors() break; } + std::lock_guard lock(buffer_mutex); + if (after_send_progress.elapsed() / 1000 >= query_context->getSettingsRef().interactive_delay) { /// Some time passed and there is a progress. @@ -754,8 +750,6 @@ void TCPHandler::processTablesStatusRequest() void TCPHandler::receiveUnexpectedTablesStatusRequest() { - std::lock_guard lock(buffer_mutex); - TablesStatusRequest skip_request; skip_request.read(*in, client_tcp_protocol_version); @@ -764,8 +758,6 @@ void TCPHandler::receiveUnexpectedTablesStatusRequest() void TCPHandler::sendPartUUIDs() { - std::lock_guard lock(buffer_mutex); - auto uuids = query_context->getPartUUIDs()->get(); if (!uuids.empty()) { @@ -788,8 +780,6 @@ void TCPHandler::sendReadTaskRequestAssumeLocked(const String & request) void TCPHandler::sendProfileInfo(const BlockStreamProfileInfo & info) { - std::lock_guard lock(buffer_mutex); - writeVarUInt(Protocol::Server::ProfileInfo, *out); info.write(*out); out->next(); @@ -798,8 +788,6 @@ void TCPHandler::sendProfileInfo(const BlockStreamProfileInfo & info) void TCPHandler::sendTotals(const Block & totals) { - std::lock_guard lock(buffer_mutex); - if (totals) { initBlockOutput(totals); @@ -816,8 +804,6 @@ void TCPHandler::sendTotals(const Block & totals) void TCPHandler::sendExtremes(const Block & extremes) { - std::lock_guard lock(buffer_mutex); - if (extremes) { initBlockOutput(extremes); @@ -834,8 +820,6 @@ void TCPHandler::sendExtremes(const Block & extremes) bool TCPHandler::receiveProxyHeader() { - std::lock_guard lock(buffer_mutex); - if (in->eof()) { LOG_WARNING(log, "Client has not sent any data."); @@ -908,8 +892,6 @@ bool TCPHandler::receiveProxyHeader() void TCPHandler::receiveHello() { - std::lock_guard lock(buffer_mutex); - /// Receive `hello` packet. UInt64 packet_type = 0; String user; @@ -967,8 +949,6 @@ void TCPHandler::receiveHello() void TCPHandler::receiveUnexpectedHello() { - std::lock_guard lock(buffer_mutex); - UInt64 skip_uint_64; String skip_string; @@ -986,8 +966,6 @@ void TCPHandler::receiveUnexpectedHello() void TCPHandler::sendHello() { - std::lock_guard lock(buffer_mutex); - writeVarUInt(Protocol::Server::Hello, *out); writeStringBinary(DBMS_NAME, *out); writeVarUInt(DBMS_VERSION_MAJOR, *out); @@ -1006,11 +984,7 @@ void TCPHandler::sendHello() bool TCPHandler::receivePacket() { UInt64 packet_type = 0; - { - std::lock_guard lock(buffer_mutex); - readVarUInt(packet_type, *in); - } - + readVarUInt(packet_type, *in); switch (packet_type) { @@ -1058,8 +1032,6 @@ bool TCPHandler::receivePacket() void TCPHandler::receiveIgnoredPartUUIDs() { - std::lock_guard lock(buffer_mutex); - state.part_uuids = true; std::vector uuids; readVectorBinary(uuids, *in); @@ -1086,11 +1058,8 @@ String TCPHandler::receiveReadTaskResponseAssumeLocked() void TCPHandler::receiveClusterNameAndSalt() { - { - std::lock_guard lock(buffer_mutex); - readStringBinary(cluster, *in); - readStringBinary(salt, *in, 32); - } + readStringBinary(cluster, *in); + readStringBinary(salt, *in, 32); try { @@ -1114,8 +1083,6 @@ void TCPHandler::receiveClusterNameAndSalt() void TCPHandler::receiveQuery() { - std::lock_guard lock(buffer_mutex); - UInt64 stage = 0; UInt64 compression = 0; @@ -1257,8 +1224,6 @@ void TCPHandler::receiveQuery() void TCPHandler::receiveUnexpectedQuery() { - std::lock_guard lock(buffer_mutex); - UInt64 skip_uint_64; String skip_string; @@ -1287,8 +1252,6 @@ void TCPHandler::receiveUnexpectedQuery() bool TCPHandler::receiveData(bool scalar) { - std::lock_guard lock(buffer_mutex); - initBlockInput(); /// The name of the temporary table for writing data, default to empty string @@ -1348,8 +1311,6 @@ bool TCPHandler::receiveData(bool scalar) void TCPHandler::receiveUnexpectedData() { - std::lock_guard lock(buffer_mutex); - String skip_external_table_name; readStringBinary(skip_external_table_name, *in); @@ -1488,8 +1449,6 @@ bool TCPHandler::isQueryCancelled() void TCPHandler::sendData(const Block & block) { - std::lock_guard lock(buffer_mutex); - initBlockOutput(block); auto prev_bytes_written_out = out->count(); @@ -1552,8 +1511,6 @@ void TCPHandler::sendLogData(const Block & block) void TCPHandler::sendTableColumns(const ColumnsDescription & columns) { - std::lock_guard lock(buffer_mutex); - writeVarUInt(Protocol::Server::TableColumns, *out); /// Send external table name (empty name is the main table) @@ -1565,8 +1522,6 @@ void TCPHandler::sendTableColumns(const ColumnsDescription & columns) void TCPHandler::sendException(const Exception & e, bool with_stack_trace) { - std::lock_guard lock(buffer_mutex); - writeVarUInt(Protocol::Server::Exception, *out); writeException(e, *out, with_stack_trace); out->next(); @@ -1575,8 +1530,6 @@ void TCPHandler::sendException(const Exception & e, bool with_stack_trace) void TCPHandler::sendEndOfStream() { - std::lock_guard lock(buffer_mutex); - state.sent_all_data = true; writeVarUInt(Protocol::Server::EndOfStream, *out); out->next(); @@ -1591,8 +1544,6 @@ void TCPHandler::updateProgress(const Progress & value) void TCPHandler::sendProgress() { - std::lock_guard lock(buffer_mutex); - writeVarUInt(Protocol::Server::Progress, *out); auto increment = state.progress.fetchAndResetPiecewiseAtomically(); increment.write(*out, client_tcp_protocol_version); @@ -1602,8 +1553,6 @@ void TCPHandler::sendProgress() void TCPHandler::sendLogs() { - std::lock_guard lock(buffer_mutex); - if (!state.logs_queue) return; diff --git a/src/Storages/StorageS3.cpp b/src/Storages/StorageS3.cpp index 678a6cc3270..7d77d420584 100644 --- a/src/Storages/StorageS3.cpp +++ b/src/Storages/StorageS3.cpp @@ -46,6 +46,95 @@ namespace ErrorCodes extern const int S3_ERROR; } +class StorageS3Source::DisclosedGlobIterator::Impl +{ + +public: + Impl(Aws::S3::S3Client & client_, const S3::URI & globbed_uri_) + : client(client_), globbed_uri(globbed_uri_) { + + if (globbed_uri.bucket.find_first_of("*?{") != globbed_uri.bucket.npos) + throw Exception("Expression can not have wildcards inside bucket name", ErrorCodes::UNEXPECTED_EXPRESSION); + + const String key_prefix = globbed_uri.key.substr(0, globbed_uri.key.find_first_of("*?{")); + + if (key_prefix.size() == globbed_uri.key.size()) + buffer.emplace_back(globbed_uri.key); + + request.SetBucket(globbed_uri.bucket); + request.SetPrefix(key_prefix); + + matcher = std::make_unique(makeRegexpPatternFromGlobs(globbed_uri.key)); + + /// Don't forget about iterator invalidation + buffer_iter = buffer.begin(); + } + + std::optional next() + { + if (buffer_iter != buffer.end()) + { + auto answer = *buffer_iter; + ++buffer_iter; + return answer; + } + + if (is_finished) + return std::nullopt; // Or throw? + + fillInternalBuffer(); + + return next(); + } + +private: + + void fillInternalBuffer() + { + buffer.clear(); + + outcome = client.ListObjectsV2(request); + if (!outcome.IsSuccess()) + throw Exception(ErrorCodes::S3_ERROR, "Could not list objects in bucket {} with prefix {}, S3 exception: {}, message: {}", + quoteString(request.GetBucket()), quoteString(request.GetPrefix()), + backQuote(outcome.GetError().GetExceptionName()), quoteString(outcome.GetError().GetMessage())); + + const auto & result_batch = outcome.GetResult().GetContents(); + + buffer.reserve(result_batch.size()); + for (const auto & row : result_batch) + { + String key = row.GetKey(); + if (re2::RE2::FullMatch(key, *matcher)) + buffer.emplace_back(std::move(key)); + } + /// Set iterator only after the whole batch is processed + buffer_iter = buffer.begin(); + + request.SetContinuationToken(outcome.GetResult().GetNextContinuationToken()); + + /// It returns false when all objects were returned + is_finished = !outcome.GetResult().GetIsTruncated(); + } + + Strings buffer; + Strings::iterator buffer_iter; + Aws::S3::S3Client client; + S3::URI globbed_uri; + Aws::S3::Model::ListObjectsV2Request request; + Aws::S3::Model::ListObjectsV2Outcome outcome; + std::unique_ptr matcher; + bool is_finished{false}; +}; + +StorageS3Source::DisclosedGlobIterator::DisclosedGlobIterator(Aws::S3::S3Client & client_, const S3::URI & globbed_uri_) + : pimpl(std::make_unique(client_, globbed_uri_)) {} + +std::optional StorageS3Source::DisclosedGlobIterator::next() +{ + return pimpl->next(); +} + Block StorageS3Source::getHeader(Block sample_block, bool with_path_column, bool with_file_column) { @@ -209,62 +298,6 @@ StorageS3::StorageS3( } -/* "Recursive" directory listing with matched paths as a result. - * Have the same method in StorageFile. - */ -Strings StorageS3::listFilesWithRegexpMatching(Aws::S3::S3Client & client, const S3::URI & globbed_uri) -{ - if (globbed_uri.bucket.find_first_of("*?{") != globbed_uri.bucket.npos) - { - throw Exception("Expression can not have wildcards inside bucket name", ErrorCodes::UNEXPECTED_EXPRESSION); - } - - const String key_prefix = globbed_uri.key.substr(0, globbed_uri.key.find_first_of("*?{")); - if (key_prefix.size() == globbed_uri.key.size()) - { - return {globbed_uri.key}; - } - - Aws::S3::Model::ListObjectsV2Request request; - request.SetBucket(globbed_uri.bucket); - request.SetPrefix(key_prefix); - - re2::RE2 matcher(makeRegexpPatternFromGlobs(globbed_uri.key)); - Strings result; - Aws::S3::Model::ListObjectsV2Outcome outcome; - int page = 0; - do - { - ++page; - outcome = client.ListObjectsV2(request); - if (!outcome.IsSuccess()) - { - if (page > 1) - throw Exception(ErrorCodes::S3_ERROR, "Could not list objects in bucket {} with prefix {}, page {}, S3 exception: {}, message: {}", - quoteString(request.GetBucket()), quoteString(request.GetPrefix()), page, - backQuote(outcome.GetError().GetExceptionName()), quoteString(outcome.GetError().GetMessage())); - - throw Exception(ErrorCodes::S3_ERROR, "Could not list objects in bucket {} with prefix {}, S3 exception: {}, message: {}", - quoteString(request.GetBucket()), quoteString(request.GetPrefix()), - backQuote(outcome.GetError().GetExceptionName()), quoteString(outcome.GetError().GetMessage())); - } - - for (const auto & row : outcome.GetResult().GetContents()) - { - String key = row.GetKey(); - std::cout << "KEY " << key << std::endl; - if (re2::RE2::FullMatch(key, matcher)) - result.emplace_back(std::move(key)); - } - - request.SetContinuationToken(outcome.GetResult().GetNextContinuationToken()); - } - while (outcome.GetResult().GetIsTruncated()); - - return result; -} - - Pipe StorageS3::read( const Names & column_names, const StorageMetadataPtr & metadata_snapshot, @@ -287,7 +320,12 @@ Pipe StorageS3::read( need_file_column = true; } - for (const String & key : listFilesWithRegexpMatching(*client_auth.client, client_auth.uri)) + /// Iterate through disclosed globs and make a source for each file + StorageS3Source::DisclosedGlobIterator glob_iterator(*client_auth.client, client_auth.uri); + /// TODO: better to put first num_streams keys into pipeline + /// and put others dynamically in runtime + while (auto key = glob_iterator.next()) + { pipes.emplace_back(std::make_shared( need_path_column, need_file_column, @@ -300,8 +338,8 @@ Pipe StorageS3::read( chooseCompressionMethod(client_auth.uri.key, compression_method), client_auth.client, client_auth.uri.bucket, - key)); - + key.value())); + } auto pipe = Pipe::unitePipes(std::move(pipes)); // It's possible to have many buckets read from s3, resize(num_streams) might open too many handles at the same time. // Using narrowPipe instead. diff --git a/src/Storages/StorageS3.h b/src/Storages/StorageS3.h index c47a88e35d9..6e9202abb6f 100644 --- a/src/Storages/StorageS3.h +++ b/src/Storages/StorageS3.h @@ -31,6 +31,17 @@ class StorageS3Source : public SourceWithProgress { public: + class DisclosedGlobIterator + { + public: + DisclosedGlobIterator(Aws::S3::S3Client &, const S3::URI &); + std::optional next(); + private: + class Impl; + /// shared_ptr to have copy constructor + std::shared_ptr pimpl; + }; + static Block getHeader(Block sample_block, bool with_path_column, bool with_file_column); StorageS3Source( @@ -125,7 +136,6 @@ private: String compression_method; String name; - static Strings listFilesWithRegexpMatching(Aws::S3::S3Client & client, const S3::URI & globbed_uri); static void updateClientAndAuthSettings(ContextPtr, ClientAuthentificaiton &); }; diff --git a/src/Storages/StorageS3Distributed.cpp b/src/Storages/StorageS3Distributed.cpp index 12a1f146ad5..2a257ed922e 100644 --- a/src/Storages/StorageS3Distributed.cpp +++ b/src/Storages/StorageS3Distributed.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -180,8 +181,7 @@ private: StorageS3Distributed::StorageS3Distributed( - IAST::Hash tree_hash_, - const String & address_hash_or_filename_, + const String & filename_, const String & access_key_id_, const String & secret_access_key_, const StorageID & table_id_, @@ -193,8 +193,7 @@ StorageS3Distributed::StorageS3Distributed( const Context & context_, const String & compression_method_) : IStorage(table_id_) - , tree_hash(tree_hash_) - , address_hash_or_filename(address_hash_or_filename_) + , filename(filename_) , cluster_name(cluster_name_) , cluster(context_.getCluster(cluster_name)->getClusterWithReplicasAsShards(context_.getSettings())) , format_name(format_name_) @@ -268,28 +267,17 @@ Pipe StorageS3Distributed::read( for (const auto & node : replicas) { connections.emplace_back(std::make_shared( - /*host=*/node.host_name, - /*port=*/node.port, - /*default_database=*/context.getGlobalContext().getCurrentDatabase(), - /*user=*/node.user, - /*password=*/node.password, - /*cluster=*/node.cluster, - /*cluster_secret=*/node.cluster_secret, + node.host_name, node.port, context.getGlobalContext().getCurrentDatabase(), + node.user, node.password, node.cluster, node.cluster_secret, "StorageS3DistributedInititiator", Protocol::Compression::Disable, Protocol::Secure::Disable )); - auto stream = std::make_shared( - /*connection=*/*connections.back(), - /*query=*/queryToString(query_info.query), - /*header=*/header, - /*context=*/context, - /*throttler=*/nullptr, - /*scalars*/scalars, - /*external_tables*/Tables(), - /*stage*/processed_stage - ); - pipes.emplace_back(std::make_shared(std::move(stream))); + + auto remote_query_executor = std::make_shared( + *connections.back(), queryToString(query_info.query), header, context, /*throttler=*/nullptr, scalars, Tables(), processed_stage); + + pipes.emplace_back(createRemoteSourcePipe(remote_query_executor, false, false, false, false)); } } diff --git a/src/Storages/StorageS3Distributed.h b/src/Storages/StorageS3Distributed.h index 23c3230c6c6..13e28d1a7aa 100644 --- a/src/Storages/StorageS3Distributed.h +++ b/src/Storages/StorageS3Distributed.h @@ -46,11 +46,16 @@ public: size_t /*max_block_size*/, unsigned /*num_streams*/) override; + QueryProcessingStage::Enum getQueryProcessingStage(const Context &, QueryProcessingStage::Enum /*to_stage*/, SelectQueryInfo &) const override + { + return QueryProcessingStage::Enum::WithMergeableState; + } + + NamesAndTypesList getVirtuals() const override; protected: StorageS3Distributed( - IAST::Hash tree_hash_, - const String & address_hash_or_filename_, + const String & filename_, const String & access_key_id_, const String & secret_access_key_, const StorageID & table_id_, @@ -65,8 +70,7 @@ protected: private: /// Connections from initiator to other nodes std::vector> connections; - IAST::Hash tree_hash; - String address_hash_or_filename; + String filename; std::string cluster_name; ClusterPtr cluster; diff --git a/src/Storages/TaskSupervisor.h b/src/Storages/TaskSupervisor.h index 7de0081d048..20e2489d120 100644 --- a/src/Storages/TaskSupervisor.h +++ b/src/Storages/TaskSupervisor.h @@ -21,34 +21,15 @@ using Task = std::string; using Tasks = std::vector; using TasksIterator = Tasks::iterator; -class S3NextTaskResolver +struct ReadTaskResolver { -public: - S3NextTaskResolver(QueryId query_id, Tasks && all_tasks) - : id(query_id) - , tasks(all_tasks) - , current(tasks.begin()) - {} - - std::string next() - { - auto it = current; - ++current; - return it == tasks.end() ? "" : *it; - } - - std::string getId() - { - return id; - } - -private: - QueryId id; - Tasks tasks; - TasksIterator current; + ReadTaskResolver(String name_, std::function callback_) + : name(name_), callback(callback_) {} + String name; + std::function callback; }; -using S3NextTaskResolverPtr = std::shared_ptr; +using ReadTaskResolverPtr = std::unique_ptr; class TaskSupervisor { @@ -57,13 +38,13 @@ public: TaskSupervisor() = default; - void registerNextTaskResolver(S3NextTaskResolverPtr resolver) + void registerNextTaskResolver(ReadTaskResolverPtr resolver) { std::lock_guard lock(mutex); - auto & target = dict[resolver->getId()]; + auto & target = dict[resolver->name]; if (target) throw Exception(fmt::format("NextTaskResolver with name {} is already registered for query {}", - target->getId(), resolver->getId()), ErrorCodes::LOGICAL_ERROR); + target->name, resolver->name), ErrorCodes::LOGICAL_ERROR); target = std::move(resolver); } @@ -74,14 +55,14 @@ public: auto it = dict.find(id); if (it == dict.end()) return ""; - auto answer = it->second->next(); + auto answer = it->second->callback(); if (answer.empty()) dict.erase(it); return answer; } private: - using ResolverDict = std::unordered_map; + using ResolverDict = std::unordered_map; ResolverDict dict; std::mutex mutex; }; diff --git a/src/TableFunctions/TableFunctionS3Distributed.cpp b/src/TableFunctions/TableFunctionS3Distributed.cpp index a5b9012e7a2..814b2586242 100644 --- a/src/TableFunctions/TableFunctionS3Distributed.cpp +++ b/src/TableFunctions/TableFunctionS3Distributed.cpp @@ -1,3 +1,5 @@ +#include +#include #include #include #include "DataStreams/RemoteBlockInputStream.h" @@ -11,6 +13,8 @@ #if USE_AWS_S3 + +#include #include #include #include @@ -29,8 +33,10 @@ namespace ErrorCodes { extern const int LOGICAL_ERROR; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int UNEXPECTED_EXPRESSION; } + void TableFunctionS3Distributed::parseArguments(const ASTPtr & ast_function, const Context & context) { /// Parse args @@ -41,32 +47,51 @@ void TableFunctionS3Distributed::parseArguments(const ASTPtr & ast_function, con ASTs & args = args_func.at(0)->children; + const auto message = fmt::format( + "The signature of table function {} could be the following:\n" \ + " - cluster, url, format, structure\n" \ + " - cluster, url, format, structure, compression_method\n" \ + " - cluster, url, access_key_id, secret_access_key, format, structure\n" \ + " - cluster, url, access_key_id, secret_access_key, format, structure, compression_method", + getName()); + if (args.size() < 4 || args.size() > 7) - throw Exception("Table function '" + getName() + "' requires 4 to 7 arguments: cluster, url," + - "[access_key_id, secret_access_key,] format, structure and [compression_method].", - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + throw Exception(message, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); for (auto & arg : args) arg = evaluateConstantExpressionOrIdentifierAsLiteral(arg, context); cluster_name = args[0]->as().value.safeGet(); - filename_or_initiator_hash = args[1]->as().value.safeGet(); + filename = args[1]->as().value.safeGet(); - if (args.size() < 5) + if (args.size() == 4) { format = args[2]->as().value.safeGet(); structure = args[3]->as().value.safeGet(); + } + else if (args.size() == 5) + { + format = args[2]->as().value.safeGet(); + structure = args[3]->as().value.safeGet(); + compression_method = args[4]->as().value.safeGet(); } - else + else if (args.size() == 6) { access_key_id = args[2]->as().value.safeGet(); secret_access_key = args[3]->as().value.safeGet(); format = args[4]->as().value.safeGet(); structure = args[5]->as().value.safeGet(); } - - if (args.size() == 5 || args.size() == 7) - compression_method = args.back()->as().value.safeGet(); + else if (args.size() == 7) + { + access_key_id = args[2]->as().value.safeGet(); + secret_access_key = args[3]->as().value.safeGet(); + format = args[4]->as().value.safeGet(); + structure = args[5]->as().value.safeGet(); + compression_method = args[4]->as().value.safeGet(); + } + else + throw Exception(message, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); } @@ -76,7 +101,7 @@ ColumnsDescription TableFunctionS3Distributed::getActualTableStructure(const Con } StoragePtr TableFunctionS3Distributed::executeImpl( - const ASTPtr & ast_function, const Context & context, + const ASTPtr & /*filename*/, const Context & context, const std::string & table_name, ColumnsDescription /*cached_columns*/) const { UInt64 max_connections = context.getSettingsRef().s3_max_connections; @@ -84,32 +109,28 @@ StoragePtr TableFunctionS3Distributed::executeImpl( /// Initiator specific logic while (context.getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY) { - auto poco_uri = Poco::URI{filename_or_initiator_hash}; - - /// This is needed, because secondary query on local replica has the same query-id - if (poco_uri.getHost().empty() || poco_uri.getPort() == 0) - break; - + auto poco_uri = Poco::URI{filename}; S3::URI s3_uri(poco_uri); StorageS3::ClientAuthentificaiton client_auth{s3_uri, access_key_id, secret_access_key, max_connections, {}, {}}; StorageS3::updateClientAndAuthSettings(context, client_auth); + StorageS3Source::DisclosedGlobIterator iterator(*client_auth.client, client_auth.uri); - auto lists = StorageS3::listFilesWithRegexpMatching(*client_auth.client, client_auth.uri); - Strings tasks; - tasks.reserve(lists.size()); + auto callback = [endpoint = client_auth.uri.endpoint, bucket = client_auth.uri.bucket, iterator = std::move(iterator)]() mutable -> String + { + if (auto value = iterator.next()) + return endpoint + '/' + bucket + '/' + *value; + return {}; + }; - for (auto & value : lists) - tasks.emplace_back(client_auth.uri.endpoint + '/' + client_auth.uri.bucket + '/' + value); - - /// Register resolver, which will give other nodes a task to execute - context.getReadTaskSupervisor()->registerNextTaskResolver(std::make_unique(context.getCurrentQueryId(), std::move(tasks))); + /// Register resolver, which will give other nodes a task std::make_unique + context.getReadTaskSupervisor()->registerNextTaskResolver( + std::make_unique(context.getCurrentQueryId(), std::move(callback))); break; } StoragePtr storage = StorageS3Distributed::create( - ast_function->getTreeHash(), - filename_or_initiator_hash, + filename, access_key_id, secret_access_key, StorageID(getDatabaseName(), table_name), @@ -137,6 +158,15 @@ void registerTableFunctionCOSDistributed(TableFunctionFactory & factory) factory.registerFunction(); } + +NamesAndTypesList StorageS3Distributed::getVirtuals() const +{ + return NamesAndTypesList{ + {"_path", std::make_shared()}, + {"_file", std::make_shared()} + }; +} + } #endif diff --git a/src/TableFunctions/TableFunctionS3Distributed.h b/src/TableFunctions/TableFunctionS3Distributed.h index a2dd526ab05..ff94eaa83e3 100644 --- a/src/TableFunctions/TableFunctionS3Distributed.h +++ b/src/TableFunctions/TableFunctionS3Distributed.h @@ -44,7 +44,7 @@ protected: void parseArguments(const ASTPtr & ast_function, const Context & context) override; String cluster_name; - String filename_or_initiator_hash; + String filename; String format; String structure; String access_key_id;