diff --git a/src/Interpreters/ClusterProxy/executeQuery.cpp b/src/Interpreters/ClusterProxy/executeQuery.cpp index 4a637db3887..7207efd9ef6 100644 --- a/src/Interpreters/ClusterProxy/executeQuery.cpp +++ b/src/Interpreters/ClusterProxy/executeQuery.cpp @@ -488,12 +488,12 @@ void executeQueryWithParallelReplicas( shard_num = column->getUInt(0); } - const auto shard_count = not_optimized_cluster->getShardCount(); ClusterPtr new_cluster = not_optimized_cluster; /// if got valid shard_num from query initiator, then parallel replicas scope is the specified shard /// shards are numbered in order of appearance in the cluster config if (shard_num > 0) { + const auto shard_count = not_optimized_cluster->getShardCount(); if (shard_num > shard_count) throw Exception( ErrorCodes::LOGICAL_ERROR, @@ -702,6 +702,43 @@ void executeQueryWithParallelReplicasCustomKey( context, query_info.query, storage_id.getDatabaseName(), storage_id.getTableName(), /*table_function_ptr=*/nullptr); executeQueryWithParallelReplicasCustomKey(query_plan, storage_id, query_info, columns, snapshot, processed_stage, header, context); } + +bool canUseParallelReplicasOnInitiator(const ContextPtr & context) +{ + if (!context->canUseParallelReplicasOnInitiator()) + return false; + + auto cluster = context->getClusterForParallelReplicas(); + if (cluster->getShardCount() == 1) + return cluster->getShardsInfo()[0].getAllNodeCount() > 1; + + /// parallel replicas with distributed table + auto scalars = context->hasQueryContext() ? context->getQueryContext()->getScalars() : Scalars{}; + UInt64 shard_num = 0; /// shard_num is 1-based, so 0 - no shard specified + const auto it = scalars.find("_shard_num"); + if (it != scalars.end()) + { + const Block & block = it->second; + const auto & column = block.safeGetByPosition(0).column; + shard_num = column->getUInt(0); + } + if (shard_num > 0) + { + const auto shard_count = cluster->getShardCount(); + if (shard_num > shard_count) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "Shard number is greater than shard count: shard_num={} shard_count={} cluster={}", + shard_num, + shard_count, + cluster->getName()); + + return cluster->getShardsInfo().at(shard_num - 1).getAllNodeCount() > 1; + } + + return false; +} + } } diff --git a/src/Interpreters/ClusterProxy/executeQuery.h b/src/Interpreters/ClusterProxy/executeQuery.h index 2c63b7b89f0..2a21f3e8255 100644 --- a/src/Interpreters/ClusterProxy/executeQuery.h +++ b/src/Interpreters/ClusterProxy/executeQuery.h @@ -58,6 +58,8 @@ using AdditionalShardFilterGenerator = std::function; AdditionalShardFilterGenerator getShardFilterGeneratorForCustomKey(const Cluster & cluster, ContextPtr context, const ColumnsDescription & columns); +bool canUseParallelReplicasOnInitiator(const ContextPtr & context); + /// Execute a distributed query, creating a query plan, from which the query pipeline can be built. /// `stream_factory` object encapsulates the logic of creating plans for a different type of query /// (currently SELECT, DESCRIBE). diff --git a/src/Planner/PlannerJoinTree.cpp b/src/Planner/PlannerJoinTree.cpp index 4c21f0d00cc..c1eb2b48045 100644 --- a/src/Planner/PlannerJoinTree.cpp +++ b/src/Planner/PlannerJoinTree.cpp @@ -920,7 +920,7 @@ JoinTreeQueryPlan buildQueryPlanForTableExpression(QueryTreeNodePtr table_expres query_plan = std::move(query_plan_parallel_replicas); } } - else if (query_context->canUseParallelReplicasOnInitiator()) + else if (ClusterProxy::canUseParallelReplicasOnInitiator(query_context)) { // (1) find read step QueryPlan::Node * node = query_plan.getRootNode();