From 479d4fa99102b53d5e49493fc66ae27ca2144b43 Mon Sep 17 00:00:00 2001 From: Amos Bird Date: Sat, 31 Jul 2021 15:45:26 +0800 Subject: [PATCH] Add shardNum() and shardCount() functions --- src/Functions/getScalar.cpp | 79 +++++++++++++++ .../ClusterProxy/IStreamFactory.h | 3 +- .../ClusterProxy/SelectStreamFactory.cpp | 97 +++---------------- .../ClusterProxy/SelectStreamFactory.h | 3 +- .../ClusterProxy/executeQuery.cpp | 12 ++- src/Interpreters/Context.cpp | 14 +++ src/Interpreters/Context.h | 4 + src/Interpreters/ExpressionAnalyzer.cpp | 2 +- .../IInterpreterUnionOrSelectQuery.h | 9 ++ src/Interpreters/SelectQueryOptions.h | 14 +++ src/Processors/QueryPlan/ReadFromRemote.cpp | 21 ++-- src/Processors/QueryPlan/ReadFromRemote.h | 4 +- .../01860_Distributed__shard_num_GROUP_BY.sql | 4 +- .../02001_shard_num_shard_count.sql | 3 + 14 files changed, 168 insertions(+), 101 deletions(-) create mode 100644 tests/queries/0_stateless/02001_shard_num_shard_count.sql diff --git a/src/Functions/getScalar.cpp b/src/Functions/getScalar.cpp index a29abd257e7..ec636ae89f5 100644 --- a/src/Functions/getScalar.cpp +++ b/src/Functions/getScalar.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -60,11 +61,89 @@ private: mutable ColumnWithTypeAndName scalar; }; + +/** Get special scalar values + */ +template +class FunctionGetSpecialScalar : public IFunction, WithContext +{ +public: + static constexpr auto name = Scalar::name; + static FunctionPtr create(ContextPtr context_) + { + return std::make_shared>(context_); + } + + static ColumnWithTypeAndName createScalar(ContextPtr context_) + { + if (const auto * block = context_->tryGetLocalScalar(Scalar::scalar_name)) + return block->getByPosition(0); + else if (context_->hasQueryContext()) + { + if (context_->getQueryContext()->hasScalar(Scalar::scalar_name)) + return context_->getQueryContext()->getScalar(Scalar::scalar_name).getByPosition(0); + } + return {DataTypeUInt32().createColumnConst(1, 0), 
std::make_shared(), Scalar::scalar_name}; + } + + explicit FunctionGetSpecialScalar(ContextPtr context_) + : WithContext(context_), scalar(createScalar(context_)), is_distributed(context_->isDistributed()) + { + } + + String getName() const override + { + return name; + } + + bool isDeterministic() const override { return false; } + + bool isDeterministicInScopeOfQuery() const override + { + return true; + } + + bool isSuitableForConstantFolding() const override { return !is_distributed; } + + size_t getNumberOfArguments() const override + { + return 0; + } + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override + { + return scalar.type; + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName &, const DataTypePtr &, size_t input_rows_count) const override + { + return ColumnConst::create(scalar.column, input_rows_count); + } + +private: + ColumnWithTypeAndName scalar; + bool is_distributed; +}; + +struct GetShardNum +{ + static constexpr auto name = "shardNum"; + static constexpr auto scalar_name = "_shard_num"; +}; + +struct GetShardCount +{ + static constexpr auto name = "shardCount"; + static constexpr auto scalar_name = "_shard_count"; +}; + } void registerFunctionGetScalar(FunctionFactory & factory) { factory.registerFunction(); + factory.registerFunction>(); + factory.registerFunction>(); } } diff --git a/src/Interpreters/ClusterProxy/IStreamFactory.h b/src/Interpreters/ClusterProxy/IStreamFactory.h index d85e97e5a2e..6360aee2f55 100644 --- a/src/Interpreters/ClusterProxy/IStreamFactory.h +++ b/src/Interpreters/ClusterProxy/IStreamFactory.h @@ -54,7 +54,8 @@ public: const ASTPtr & table_func_ptr, ContextPtr context, std::vector & local_plans, - Shards & remote_shards) = 0; + Shards & remote_shards, + UInt32 shard_count) = 0; }; } diff --git a/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp b/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp index efad9f899d4..961de45c491 100644 --- 
a/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp +++ b/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp @@ -47,61 +47,6 @@ SelectStreamFactory::SelectStreamFactory( namespace { -/// Special support for the case when `_shard_num` column is used in GROUP BY key expression. -/// This column is a constant for shard. -/// Constant expression with this column may be removed from intermediate header. -/// However, this column is not constant for initiator, and it expect intermediate header has it. -/// -/// To fix it, the following trick is applied. -/// We check all GROUP BY keys which depend only on `_shard_num`. -/// Calculate such expression for current shard if it is used in header. -/// Those columns will be added to modified header as already known constants. -/// -/// For local shard, missed constants will be added by converting actions. -/// For remote shard, RemoteQueryExecutor will automatically add missing constant. -Block evaluateConstantGroupByKeysWithShardNumber( - const ContextPtr & context, const ASTPtr & query_ast, const Block & header, UInt32 shard_num) -{ - Block res; - - ColumnWithTypeAndName shard_num_col; - shard_num_col.type = std::make_shared(); - shard_num_col.column = shard_num_col.type->createColumnConst(0, shard_num); - shard_num_col.name = "_shard_num"; - - if (auto group_by = query_ast->as().groupBy()) - { - for (const auto & elem : group_by->children) - { - String key_name = elem->getColumnName(); - if (header.has(key_name)) - { - auto ast = elem->clone(); - - RequiredSourceColumnsVisitor::Data columns_context; - RequiredSourceColumnsVisitor(columns_context).visit(ast); - - auto required_columns = columns_context.requiredColumns(); - if (required_columns.size() != 1 || required_columns.count("_shard_num") == 0) - continue; - - Block block({shard_num_col}); - auto syntax_result = TreeRewriter(context).analyze(ast, {NameAndTypePair{shard_num_col.name, shard_num_col.type}}); - ExpressionAnalyzer(ast, syntax_result, 
context).getActions(true, false)->execute(block); - - res.insert(block.getByName(key_name)); - } - } - } - - /// We always add _shard_num constant just in case. - /// For initial query it is considered as a column from table, and may be required by intermediate block. - if (!res.has(shard_num_col.name)) - res.insert(std::move(shard_num_col)); - - return res; -} - ActionsDAGPtr getConvertingDAG(const Block & block, const Block & header) { /// Convert header structure to expected. @@ -128,13 +73,16 @@ std::unique_ptr createLocalPlan( const ASTPtr & query_ast, const Block & header, ContextPtr context, - QueryProcessingStage::Enum processed_stage) + QueryProcessingStage::Enum processed_stage, + UInt32 shard_num, + UInt32 shard_count) { checkStackSize(); auto query_plan = std::make_unique(); - InterpreterSelectQuery interpreter(query_ast, context, SelectQueryOptions(processed_stage)); + InterpreterSelectQuery interpreter( + query_ast, context, SelectQueryOptions(processed_stage).setShardInfo(shard_num, shard_count)); interpreter.buildQueryPlan(*query_plan); addConvertingActions(*query_plan, header); @@ -151,38 +99,27 @@ void SelectStreamFactory::createForShard( const ASTPtr & table_func_ptr, ContextPtr context, std::vector & local_plans, - Shards & remote_shards) + Shards & remote_shards, + UInt32 shard_count) { auto modified_query_ast = query_ast->clone(); - auto modified_header = header; if (has_virtual_shard_num_column) - { VirtualColumnUtils::rewriteEntityInAst(modified_query_ast, "_shard_num", shard_info.shard_num, "toUInt32"); - auto shard_num_constants = evaluateConstantGroupByKeysWithShardNumber(context, query_ast, modified_header, shard_info.shard_num); - - for (auto & col : shard_num_constants) - { - if (modified_header.has(col.name)) - modified_header.getByName(col.name).column = std::move(col.column); - else - modified_header.insert(std::move(col)); - } - } auto emplace_local_stream = [&]() { - local_plans.emplace_back(createLocalPlan(modified_query_ast, 
modified_header, context, processed_stage)); - addConvertingActions(*local_plans.back(), header); + local_plans.emplace_back(createLocalPlan(modified_query_ast, header, context, processed_stage, shard_info.shard_num, shard_count)); }; - auto emplace_remote_stream = [&]() + auto emplace_remote_stream = [&](bool lazy = false, UInt32 local_delay = 0) { remote_shards.emplace_back(Shard{ .query = modified_query_ast, - .header = modified_header, + .header = header, .shard_num = shard_info.shard_num, .pool = shard_info.pool, - .lazy = false + .lazy = lazy, + .local_delay = local_delay, }); }; @@ -273,15 +210,7 @@ void SelectStreamFactory::createForShard( /// Try our luck with remote replicas, but if they are stale too, then fallback to local replica. /// Do it lazily to avoid connecting in the main thread. - - remote_shards.emplace_back(Shard{ - .query = modified_query_ast, - .header = modified_header, - .shard_num = shard_info.shard_num, - .pool = shard_info.pool, - .lazy = true, - .local_delay = local_delay - }); + emplace_remote_stream(true /* lazy */, local_delay); } else emplace_remote_stream(); diff --git a/src/Interpreters/ClusterProxy/SelectStreamFactory.h b/src/Interpreters/ClusterProxy/SelectStreamFactory.h index d041ac8ea5f..dda6fb96f01 100644 --- a/src/Interpreters/ClusterProxy/SelectStreamFactory.h +++ b/src/Interpreters/ClusterProxy/SelectStreamFactory.h @@ -26,7 +26,8 @@ public: const ASTPtr & table_func_ptr, ContextPtr context, std::vector & local_plans, - Shards & remote_shards) override; + Shards & remote_shards, + UInt32 shard_count) override; private: const Block header; diff --git a/src/Interpreters/ClusterProxy/executeQuery.cpp b/src/Interpreters/ClusterProxy/executeQuery.cpp index d3a1b40a8e3..95b279fd59b 100644 --- a/src/Interpreters/ClusterProxy/executeQuery.cpp +++ b/src/Interpreters/ClusterProxy/executeQuery.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace DB @@ -165,12 +166,14 @@ void executeQuery( 
stream_factory.createForShard(shard_info, query_ast_for_shard, main_table, table_func_ptr, - new_context, plans, remote_shards); + new_context, plans, remote_shards, shards); } if (!remote_shards.empty()) { - const Scalars & scalars = context->hasQueryContext() ? context->getQueryContext()->getScalars() : Scalars{}; + Scalars scalars = context->hasQueryContext() ? context->getQueryContext()->getScalars() : Scalars{}; + scalars.emplace( + "_shard_count", Block{{DataTypeUInt32().createColumnConst(1, shards), std::make_shared(), "_shard_count"}}); auto external_tables = context->getExternalTables(); auto plan = std::make_unique(); @@ -182,9 +185,10 @@ void executeQuery( table_func_ptr, new_context, throttler, - scalars, + std::move(scalars), std::move(external_tables), - log); + log, + shards); read_from_remote->setStepDescription("Read from remote replica"); plan->addStep(std::move(read_from_remote)); diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index 33ebfbd21e0..43d2e4712b7 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -997,6 +997,13 @@ const Block & Context::getScalar(const String & name) const return it->second; } +const Block * Context::tryGetLocalScalar(const String & name) const +{ + auto it = local_scalars.find(name); + if (local_scalars.end() == it) + return nullptr; + return &it->second; +} Tables Context::getExternalTables() const { @@ -1056,6 +1063,13 @@ void Context::addScalar(const String & name, const Block & block) } +void Context::addLocalScalar(const String & name, const Block & block) +{ + assert(!isGlobalContext() || getApplicationType() == ApplicationType::LOCAL); + local_scalars[name] = block; +} + + bool Context::hasScalar(const String & name) const { assert(!isGlobalContext() || getApplicationType() == ApplicationType::LOCAL); diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 591d1dba46f..144f8b62d51 100644 --- a/src/Interpreters/Context.h +++ 
b/src/Interpreters/Context.h @@ -197,6 +197,7 @@ private: /// Thus, used in HTTP interface. If not specified - then some globally default format is used. TemporaryTablesMapping external_tables_mapping; Scalars scalars; + Scalars local_scalars; /// Fields for distributed s3 function std::optional next_task_callback; @@ -455,6 +456,9 @@ public: void addScalar(const String & name, const Block & block); bool hasScalar(const String & name) const; + const Block * tryGetLocalScalar(const String & name) const; + void addLocalScalar(const String & name, const Block & block); + const QueryAccessInfo & getQueryAccessInfo() const { return query_access_info; } void addQueryAccessInfo( const String & quoted_database_name, diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index 1c1309ec916..66c1cb9ad7b 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -249,7 +249,7 @@ void ExpressionAnalyzer::analyzeAggregation() throw Exception("Unknown identifier (in GROUP BY): " + column_name, ErrorCodes::UNKNOWN_IDENTIFIER); /// Only removes constant keys if it's an initiator or distributed_group_by_no_merge is enabled. - if (getContext()->getClientInfo().distributed_depth == 0 && settings.distributed_group_by_no_merge > 0) + if (getContext()->getClientInfo().distributed_depth == 0 || settings.distributed_group_by_no_merge > 0) { /// Constant expressions have non-null column pointer at this stage. 
if (node->column && isColumnConst(*node->column)) diff --git a/src/Interpreters/IInterpreterUnionOrSelectQuery.h b/src/Interpreters/IInterpreterUnionOrSelectQuery.h index 0b07f27e14a..cc960e748f6 100644 --- a/src/Interpreters/IInterpreterUnionOrSelectQuery.h +++ b/src/Interpreters/IInterpreterUnionOrSelectQuery.h @@ -4,6 +4,7 @@ #include #include #include +#include namespace DB { @@ -16,6 +17,14 @@ public: , options(options_) , max_streams(context->getSettingsRef().max_threads) { + if (options.shard_num) + context->addLocalScalar( + "_shard_num", + Block{{DataTypeUInt32().createColumnConst(1, *options.shard_num), std::make_shared(), "_shard_num"}}); + if (options.shard_count) + context->addLocalScalar( + "_shard_count", + Block{{DataTypeUInt32().createColumnConst(1, *options.shard_count), std::make_shared(), "_shard_count"}}); } virtual void buildQueryPlan(QueryPlan & query_plan) = 0; diff --git a/src/Interpreters/SelectQueryOptions.h b/src/Interpreters/SelectQueryOptions.h index 52ce7c83741..89402609e1c 100644 --- a/src/Interpreters/SelectQueryOptions.h +++ b/src/Interpreters/SelectQueryOptions.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace DB { @@ -45,6 +46,12 @@ struct SelectQueryOptions bool is_subquery = false; // non-subquery can also have subquery_depth > 0, e.g. insert select bool with_all_cols = false; /// asterisk include materialized and aliased columns + /// These two fields are used to evaluate the shardNum() and shardCount() functions when + /// prefer_localhost_replica == 1 and local instance is selected. They are needed because local + /// instance might have multiple shards and scalars can only hold one value. 
+ std::optional shard_num; + std::optional shard_count; + SelectQueryOptions( QueryProcessingStage::Enum stage = QueryProcessingStage::Complete, size_t depth = 0, @@ -124,6 +131,13 @@ struct SelectQueryOptions with_all_cols = value; return *this; } + + SelectQueryOptions & setShardInfo(UInt32 shard_num_, UInt32 shard_count_) + { + shard_num = shard_num_; + shard_count = shard_count_; + return *this; + } }; } diff --git a/src/Processors/QueryPlan/ReadFromRemote.cpp b/src/Processors/QueryPlan/ReadFromRemote.cpp index 63270237e44..506ef795473 100644 --- a/src/Processors/QueryPlan/ReadFromRemote.cpp +++ b/src/Processors/QueryPlan/ReadFromRemote.cpp @@ -67,13 +67,16 @@ static std::unique_ptr createLocalPlan( const ASTPtr & query_ast, const Block & header, ContextPtr context, - QueryProcessingStage::Enum processed_stage) + QueryProcessingStage::Enum processed_stage, + UInt32 shard_num, + UInt32 shard_count) { checkStackSize(); auto query_plan = std::make_unique(); - InterpreterSelectQuery interpreter(query_ast, context, SelectQueryOptions(processed_stage)); + InterpreterSelectQuery interpreter( + query_ast, context, SelectQueryOptions(processed_stage).setShardInfo(shard_num, shard_count)); interpreter.buildQueryPlan(*query_plan); addConvertingActions(*query_plan, header); @@ -92,7 +95,8 @@ ReadFromRemote::ReadFromRemote( ThrottlerPtr throttler_, Scalars scalars_, Tables external_tables_, - Poco::Logger * log_) + Poco::Logger * log_, + UInt32 shard_count_) : ISourceStep(DataStream{.header = std::move(header_)}) , shards(std::move(shards_)) , stage(stage_) @@ -103,6 +107,7 @@ ReadFromRemote::ReadFromRemote( , scalars(std::move(scalars_)) , external_tables(std::move(external_tables_)) , log(log_) + , shard_count(shard_count_) { } @@ -119,12 +124,12 @@ void ReadFromRemote::addLazyPipe(Pipes & pipes, const ClusterProxy::IStreamFacto } auto lazily_create_stream = [ - pool = shard.pool, shard_num = shard.shard_num, query = shard.query, header = shard.header, + pool = 
shard.pool, shard_num = shard.shard_num, shard_count = shard_count, query = shard.query, header = shard.header, context = context, throttler = throttler, main_table = main_table, table_func_ptr = table_func_ptr, scalars = scalars, external_tables = external_tables, stage = stage, local_delay = shard.local_delay, - add_agg_info, add_totals, add_extremes, async_read]() + add_agg_info, add_totals, add_extremes, async_read]() mutable -> Pipe { auto current_settings = context->getSettingsRef(); @@ -157,7 +162,7 @@ void ReadFromRemote::addLazyPipe(Pipes & pipes, const ClusterProxy::IStreamFacto if (try_results.empty() || local_delay < max_remote_delay) { - auto plan = createLocalPlan(query, header, context, stage); + auto plan = createLocalPlan(query, header, context, stage, shard_num, shard_count); return QueryPipeline::getPipe(std::move(*plan->buildQueryPipeline( QueryPlanOptimizationSettings::fromContext(context), BuildQueryPipelineSettings::fromContext(context)))); @@ -171,6 +176,8 @@ void ReadFromRemote::addLazyPipe(Pipes & pipes, const ClusterProxy::IStreamFacto String query_string = formattedAST(query); + scalars["_shard_num"] + = Block{{DataTypeUInt32().createColumnConst(1, shard_num), std::make_shared(), "_shard_num"}}; auto remote_query_executor = std::make_shared( pool, std::move(connections), query_string, header, context, throttler, scalars, external_tables, stage); @@ -197,6 +204,8 @@ void ReadFromRemote::addPipe(Pipes & pipes, const ClusterProxy::IStreamFactory:: String query_string = formattedAST(shard.query); + scalars["_shard_num"] + = Block{{DataTypeUInt32().createColumnConst(1, shard.shard_num), std::make_shared(), "_shard_num"}}; auto remote_query_executor = std::make_shared( shard.pool, query_string, shard.header, context, throttler, scalars, external_tables, stage); remote_query_executor->setLogger(log); diff --git a/src/Processors/QueryPlan/ReadFromRemote.h b/src/Processors/QueryPlan/ReadFromRemote.h index 61099299c36..ba0060d5470 100644 --- 
a/src/Processors/QueryPlan/ReadFromRemote.h +++ b/src/Processors/QueryPlan/ReadFromRemote.h @@ -29,7 +29,8 @@ public: ThrottlerPtr throttler_, Scalars scalars_, Tables external_tables_, - Poco::Logger * log_); + Poco::Logger * log_, + UInt32 shard_count_); String getName() const override { return "ReadFromRemote"; } @@ -50,6 +51,7 @@ private: Poco::Logger * log; + UInt32 shard_count; void addLazyPipe(Pipes & pipes, const ClusterProxy::IStreamFactory::Shard & shard); void addPipe(Pipes & pipes, const ClusterProxy::IStreamFactory::Shard & shard); }; diff --git a/tests/queries/0_stateless/01860_Distributed__shard_num_GROUP_BY.sql b/tests/queries/0_stateless/01860_Distributed__shard_num_GROUP_BY.sql index 91215fd8ee6..d8a86b7799e 100644 --- a/tests/queries/0_stateless/01860_Distributed__shard_num_GROUP_BY.sql +++ b/tests/queries/0_stateless/01860_Distributed__shard_num_GROUP_BY.sql @@ -11,6 +11,4 @@ SELECT _shard_num + dummy s, count() FROM remote('127.0.0.{1,2}', system.one) GR SELECT _shard_num FROM remote('127.0.0.{1,2}', system.one) ORDER BY _shard_num; SELECT _shard_num s FROM remote('127.0.0.{1,2}', system.one) ORDER BY _shard_num; -SELECT _shard_num s, count() FROM remote('127.0.0.{1,2}', system.one) GROUP BY s order by s; - -select materialize(_shard_num), * from remote('127.{1,2}', system.one) limit 1 by dummy format Null; +SELECT _shard_num, count() FROM remote('127.0.0.{1,2}', system.one) GROUP BY _shard_num order by _shard_num; diff --git a/tests/queries/0_stateless/02001_shard_num_shard_count.sql b/tests/queries/0_stateless/02001_shard_num_shard_count.sql new file mode 100644 index 00000000000..daf1084a614 --- /dev/null +++ b/tests/queries/0_stateless/02001_shard_num_shard_count.sql @@ -0,0 +1,3 @@ +select shardNum() n, shardCount() c; +select shardNum() n, shardCount() c from remote('127.0.0.{1,2,3}', system.one) order by n settings prefer_localhost_replica = 0; +select shardNum() n, shardCount() c from remote('127.0.0.{1,2,3}', system.one) order by n 
settings prefer_localhost_replica = 1;