Add shardNum() and shardCount() functions

This commit is contained in:
Amos Bird 2021-07-31 15:45:26 +08:00
parent cd302eacc1
commit 479d4fa991
No known key found for this signature in database
GPG Key ID: 80D430DCBECFEDB4
14 changed files with 168 additions and 101 deletions

View File

@ -2,6 +2,7 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnString.h>
#include <Interpreters/Context.h>
#include <Common/Macros.h>
@ -60,11 +61,89 @@ private:
mutable ColumnWithTypeAndName scalar;
};
/** Get special scalar values
*/
template <typename Scalar>
class FunctionGetSpecialScalar : public IFunction, WithContext
{
public:
static constexpr auto name = Scalar::name;
static FunctionPtr create(ContextPtr context_)
{
return std::make_shared<FunctionGetSpecialScalar<Scalar>>(context_);
}
static ColumnWithTypeAndName createScalar(ContextPtr context_)
{
if (const auto * block = context_->tryGetLocalScalar(Scalar::scalar_name))
return block->getByPosition(0);
else if (context_->hasQueryContext())
{
if (context_->getQueryContext()->hasScalar(Scalar::scalar_name))
return context_->getQueryContext()->getScalar(Scalar::scalar_name).getByPosition(0);
}
return {DataTypeUInt32().createColumnConst(1, 0), std::make_shared<DataTypeUInt32>(), Scalar::scalar_name};
}
explicit FunctionGetSpecialScalar(ContextPtr context_)
: WithContext(context_), scalar(createScalar(context_)), is_distributed(context_->isDistributed())
{
}
String getName() const override
{
return name;
}
bool isDeterministic() const override { return false; }
bool isDeterministicInScopeOfQuery() const override
{
return true;
}
bool isSuitableForConstantFolding() const override { return !is_distributed; }
size_t getNumberOfArguments() const override
{
return 0;
}
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const override
{
return scalar.type;
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName &, const DataTypePtr &, size_t input_rows_count) const override
{
return ColumnConst::create(scalar.column, input_rows_count);
}
private:
ColumnWithTypeAndName scalar;
bool is_distributed;
};
struct GetShardNum
{
static constexpr auto name = "shardNum";
static constexpr auto scalar_name = "_shard_num";
};
struct GetShardCount
{
static constexpr auto name = "shardCount";
static constexpr auto scalar_name = "_shard_count";
};
}
void registerFunctionGetScalar(FunctionFactory & factory)
{
factory.registerFunction<FunctionGetScalar>();
factory.registerFunction<FunctionGetSpecialScalar<GetShardNum>>();
factory.registerFunction<FunctionGetSpecialScalar<GetShardCount>>();
}
}

View File

@ -54,7 +54,8 @@ public:
const ASTPtr & table_func_ptr,
ContextPtr context,
std::vector<QueryPlanPtr> & local_plans,
Shards & remote_shards) = 0;
Shards & remote_shards,
UInt32 shard_count) = 0;
};
}

View File

@ -47,61 +47,6 @@ SelectStreamFactory::SelectStreamFactory(
namespace
{
/// Special support for the case when `_shard_num` column is used in GROUP BY key expression.
/// This column is a constant for shard.
/// Constant expression with this column may be removed from intermediate header.
/// However, this column is not constant for initiator, and it expect intermediate header has it.
///
/// To fix it, the following trick is applied.
/// We check all GROUP BY keys which depend only on `_shard_num`.
/// Calculate such expression for current shard if it is used in header.
/// Those columns will be added to modified header as already known constants.
///
/// For local shard, missed constants will be added by converting actions.
/// For remote shard, RemoteQueryExecutor will automatically add missing constant.
Block evaluateConstantGroupByKeysWithShardNumber(
const ContextPtr & context, const ASTPtr & query_ast, const Block & header, UInt32 shard_num)
{
Block res;
ColumnWithTypeAndName shard_num_col;
shard_num_col.type = std::make_shared<DataTypeUInt32>();
shard_num_col.column = shard_num_col.type->createColumnConst(0, shard_num);
shard_num_col.name = "_shard_num";
if (auto group_by = query_ast->as<ASTSelectQuery &>().groupBy())
{
for (const auto & elem : group_by->children)
{
String key_name = elem->getColumnName();
if (header.has(key_name))
{
auto ast = elem->clone();
RequiredSourceColumnsVisitor::Data columns_context;
RequiredSourceColumnsVisitor(columns_context).visit(ast);
auto required_columns = columns_context.requiredColumns();
if (required_columns.size() != 1 || required_columns.count("_shard_num") == 0)
continue;
Block block({shard_num_col});
auto syntax_result = TreeRewriter(context).analyze(ast, {NameAndTypePair{shard_num_col.name, shard_num_col.type}});
ExpressionAnalyzer(ast, syntax_result, context).getActions(true, false)->execute(block);
res.insert(block.getByName(key_name));
}
}
}
/// We always add _shard_num constant just in case.
/// For initial query it is considered as a column from table, and may be required by intermediate block.
if (!res.has(shard_num_col.name))
res.insert(std::move(shard_num_col));
return res;
}
ActionsDAGPtr getConvertingDAG(const Block & block, const Block & header)
{
/// Convert header structure to expected.
@ -128,13 +73,16 @@ std::unique_ptr<QueryPlan> createLocalPlan(
const ASTPtr & query_ast,
const Block & header,
ContextPtr context,
QueryProcessingStage::Enum processed_stage)
QueryProcessingStage::Enum processed_stage,
UInt32 shard_num,
UInt32 shard_count)
{
checkStackSize();
auto query_plan = std::make_unique<QueryPlan>();
InterpreterSelectQuery interpreter(query_ast, context, SelectQueryOptions(processed_stage));
InterpreterSelectQuery interpreter(
query_ast, context, SelectQueryOptions(processed_stage).setShardInfo(shard_num, shard_count));
interpreter.buildQueryPlan(*query_plan);
addConvertingActions(*query_plan, header);
@ -151,38 +99,27 @@ void SelectStreamFactory::createForShard(
const ASTPtr & table_func_ptr,
ContextPtr context,
std::vector<QueryPlanPtr> & local_plans,
Shards & remote_shards)
Shards & remote_shards,
UInt32 shard_count)
{
auto modified_query_ast = query_ast->clone();
auto modified_header = header;
if (has_virtual_shard_num_column)
{
VirtualColumnUtils::rewriteEntityInAst(modified_query_ast, "_shard_num", shard_info.shard_num, "toUInt32");
auto shard_num_constants = evaluateConstantGroupByKeysWithShardNumber(context, query_ast, modified_header, shard_info.shard_num);
for (auto & col : shard_num_constants)
{
if (modified_header.has(col.name))
modified_header.getByName(col.name).column = std::move(col.column);
else
modified_header.insert(std::move(col));
}
}
auto emplace_local_stream = [&]()
{
local_plans.emplace_back(createLocalPlan(modified_query_ast, modified_header, context, processed_stage));
addConvertingActions(*local_plans.back(), header);
local_plans.emplace_back(createLocalPlan(modified_query_ast, header, context, processed_stage, shard_info.shard_num, shard_count));
};
auto emplace_remote_stream = [&]()
auto emplace_remote_stream = [&](bool lazy = false, UInt32 local_delay = 0)
{
remote_shards.emplace_back(Shard{
.query = modified_query_ast,
.header = modified_header,
.header = header,
.shard_num = shard_info.shard_num,
.pool = shard_info.pool,
.lazy = false
.lazy = lazy,
.local_delay = local_delay,
});
};
@ -273,15 +210,7 @@ void SelectStreamFactory::createForShard(
/// Try our luck with remote replicas, but if they are stale too, then fallback to local replica.
/// Do it lazily to avoid connecting in the main thread.
remote_shards.emplace_back(Shard{
.query = modified_query_ast,
.header = modified_header,
.shard_num = shard_info.shard_num,
.pool = shard_info.pool,
.lazy = true,
.local_delay = local_delay
});
emplace_remote_stream(true /* lazy */, local_delay);
}
else
emplace_remote_stream();

View File

@ -26,7 +26,8 @@ public:
const ASTPtr & table_func_ptr,
ContextPtr context,
std::vector<QueryPlanPtr> & local_plans,
Shards & remote_shards) override;
Shards & remote_shards,
UInt32 shard_count) override;
private:
const Block header;

View File

@ -11,6 +11,7 @@
#include <Processors/QueryPlan/ReadFromRemote.h>
#include <Processors/QueryPlan/UnionStep.h>
#include <Storages/SelectQueryInfo.h>
#include <DataTypes/DataTypesNumber.h>
namespace DB
@ -165,12 +166,14 @@ void executeQuery(
stream_factory.createForShard(shard_info,
query_ast_for_shard, main_table, table_func_ptr,
new_context, plans, remote_shards);
new_context, plans, remote_shards, shards);
}
if (!remote_shards.empty())
{
const Scalars & scalars = context->hasQueryContext() ? context->getQueryContext()->getScalars() : Scalars{};
Scalars scalars = context->hasQueryContext() ? context->getQueryContext()->getScalars() : Scalars{};
scalars.emplace(
"_shard_count", Block{{DataTypeUInt32().createColumnConst(1, shards), std::make_shared<DataTypeUInt32>(), "_shard_count"}});
auto external_tables = context->getExternalTables();
auto plan = std::make_unique<QueryPlan>();
@ -182,9 +185,10 @@ void executeQuery(
table_func_ptr,
new_context,
throttler,
scalars,
std::move(scalars),
std::move(external_tables),
log);
log,
shards);
read_from_remote->setStepDescription("Read from remote replica");
plan->addStep(std::move(read_from_remote));

View File

@ -997,6 +997,13 @@ const Block & Context::getScalar(const String & name) const
return it->second;
}
const Block * Context::tryGetLocalScalar(const String & name) const
{
auto it = local_scalars.find(name);
if (local_scalars.end() == it)
return nullptr;
return &it->second;
}
Tables Context::getExternalTables() const
{
@ -1056,6 +1063,13 @@ void Context::addScalar(const String & name, const Block & block)
}
void Context::addLocalScalar(const String & name, const Block & block)
{
assert(!isGlobalContext() || getApplicationType() == ApplicationType::LOCAL);
local_scalars[name] = block;
}
bool Context::hasScalar(const String & name) const
{
assert(!isGlobalContext() || getApplicationType() == ApplicationType::LOCAL);

View File

@ -197,6 +197,7 @@ private:
/// Thus, used in HTTP interface. If not specified - then some globally default format is used.
TemporaryTablesMapping external_tables_mapping;
Scalars scalars;
Scalars local_scalars;
/// Fields for distributed s3 function
std::optional<ReadTaskCallback> next_task_callback;
@ -455,6 +456,9 @@ public:
void addScalar(const String & name, const Block & block);
bool hasScalar(const String & name) const;
const Block * tryGetLocalScalar(const String & name) const;
void addLocalScalar(const String & name, const Block & block);
const QueryAccessInfo & getQueryAccessInfo() const { return query_access_info; }
void addQueryAccessInfo(
const String & quoted_database_name,

View File

@ -249,7 +249,7 @@ void ExpressionAnalyzer::analyzeAggregation()
throw Exception("Unknown identifier (in GROUP BY): " + column_name, ErrorCodes::UNKNOWN_IDENTIFIER);
/// Only removes constant keys if it's an initiator or distributed_group_by_no_merge is enabled.
if (getContext()->getClientInfo().distributed_depth == 0 && settings.distributed_group_by_no_merge > 0)
if (getContext()->getClientInfo().distributed_depth == 0 || settings.distributed_group_by_no_merge > 0)
{
/// Constant expressions have non-null column pointer at this stage.
if (node->column && isColumnConst(*node->column))

View File

@ -4,6 +4,7 @@
#include <Interpreters/IInterpreter.h>
#include <Interpreters/SelectQueryOptions.h>
#include <Parsers/IAST_fwd.h>
#include <DataTypes/DataTypesNumber.h>
namespace DB
{
@ -16,6 +17,14 @@ public:
, options(options_)
, max_streams(context->getSettingsRef().max_threads)
{
if (options.shard_num)
context->addLocalScalar(
"_shard_num",
Block{{DataTypeUInt32().createColumnConst(1, *options.shard_num), std::make_shared<DataTypeUInt32>(), "_shard_num"}});
if (options.shard_count)
context->addLocalScalar(
"_shard_count",
Block{{DataTypeUInt32().createColumnConst(1, *options.shard_count), std::make_shared<DataTypeUInt32>(), "_shard_count"}});
}
virtual void buildQueryPlan(QueryPlan & query_plan) = 0;

View File

@ -1,6 +1,7 @@
#pragma once
#include <Core/QueryProcessingStage.h>
#include <optional>
namespace DB
{
@ -45,6 +46,12 @@ struct SelectQueryOptions
bool is_subquery = false; // non-subquery can also have subquery_depth > 0, e.g. insert select
bool with_all_cols = false; /// asterisk include materialized and aliased columns
/// These two fields are used to evaluate getShardNum() and getShardCount() function when
/// prefer_localhost_replica == 1 and local instance is selected. They are needed because local
/// instance might have multiple shards and scalars can only hold one value.
std::optional<UInt32> shard_num;
std::optional<UInt32> shard_count;
SelectQueryOptions(
QueryProcessingStage::Enum stage = QueryProcessingStage::Complete,
size_t depth = 0,
@ -124,6 +131,13 @@ struct SelectQueryOptions
with_all_cols = value;
return *this;
}
SelectQueryOptions & setShardInfo(UInt32 shard_num_, UInt32 shard_count_)
{
shard_num = shard_num_;
shard_count = shard_count_;
return *this;
}
};
}

View File

@ -67,13 +67,16 @@ static std::unique_ptr<QueryPlan> createLocalPlan(
const ASTPtr & query_ast,
const Block & header,
ContextPtr context,
QueryProcessingStage::Enum processed_stage)
QueryProcessingStage::Enum processed_stage,
UInt32 shard_num,
UInt32 shard_count)
{
checkStackSize();
auto query_plan = std::make_unique<QueryPlan>();
InterpreterSelectQuery interpreter(query_ast, context, SelectQueryOptions(processed_stage));
InterpreterSelectQuery interpreter(
query_ast, context, SelectQueryOptions(processed_stage).setShardInfo(shard_num, shard_count));
interpreter.buildQueryPlan(*query_plan);
addConvertingActions(*query_plan, header);
@ -92,7 +95,8 @@ ReadFromRemote::ReadFromRemote(
ThrottlerPtr throttler_,
Scalars scalars_,
Tables external_tables_,
Poco::Logger * log_)
Poco::Logger * log_,
UInt32 shard_count_)
: ISourceStep(DataStream{.header = std::move(header_)})
, shards(std::move(shards_))
, stage(stage_)
@ -103,6 +107,7 @@ ReadFromRemote::ReadFromRemote(
, scalars(std::move(scalars_))
, external_tables(std::move(external_tables_))
, log(log_)
, shard_count(shard_count_)
{
}
@ -119,12 +124,12 @@ void ReadFromRemote::addLazyPipe(Pipes & pipes, const ClusterProxy::IStreamFacto
}
auto lazily_create_stream = [
pool = shard.pool, shard_num = shard.shard_num, query = shard.query, header = shard.header,
pool = shard.pool, shard_num = shard.shard_num, shard_count = shard_count, query = shard.query, header = shard.header,
context = context, throttler = throttler,
main_table = main_table, table_func_ptr = table_func_ptr,
scalars = scalars, external_tables = external_tables,
stage = stage, local_delay = shard.local_delay,
add_agg_info, add_totals, add_extremes, async_read]()
add_agg_info, add_totals, add_extremes, async_read]() mutable
-> Pipe
{
auto current_settings = context->getSettingsRef();
@ -157,7 +162,7 @@ void ReadFromRemote::addLazyPipe(Pipes & pipes, const ClusterProxy::IStreamFacto
if (try_results.empty() || local_delay < max_remote_delay)
{
auto plan = createLocalPlan(query, header, context, stage);
auto plan = createLocalPlan(query, header, context, stage, shard_num, shard_count);
return QueryPipeline::getPipe(std::move(*plan->buildQueryPipeline(
QueryPlanOptimizationSettings::fromContext(context),
BuildQueryPipelineSettings::fromContext(context))));
@ -171,6 +176,8 @@ void ReadFromRemote::addLazyPipe(Pipes & pipes, const ClusterProxy::IStreamFacto
String query_string = formattedAST(query);
scalars["_shard_num"]
= Block{{DataTypeUInt32().createColumnConst(1, shard_num), std::make_shared<DataTypeUInt32>(), "_shard_num"}};
auto remote_query_executor = std::make_shared<RemoteQueryExecutor>(
pool, std::move(connections), query_string, header, context, throttler, scalars, external_tables, stage);
@ -197,6 +204,8 @@ void ReadFromRemote::addPipe(Pipes & pipes, const ClusterProxy::IStreamFactory::
String query_string = formattedAST(shard.query);
scalars["_shard_num"]
= Block{{DataTypeUInt32().createColumnConst(1, shard.shard_num), std::make_shared<DataTypeUInt32>(), "_shard_num"}};
auto remote_query_executor = std::make_shared<RemoteQueryExecutor>(
shard.pool, query_string, shard.header, context, throttler, scalars, external_tables, stage);
remote_query_executor->setLogger(log);

View File

@ -29,7 +29,8 @@ public:
ThrottlerPtr throttler_,
Scalars scalars_,
Tables external_tables_,
Poco::Logger * log_);
Poco::Logger * log_,
UInt32 shard_count_);
String getName() const override { return "ReadFromRemote"; }
@ -50,6 +51,7 @@ private:
Poco::Logger * log;
UInt32 shard_count;
void addLazyPipe(Pipes & pipes, const ClusterProxy::IStreamFactory::Shard & shard);
void addPipe(Pipes & pipes, const ClusterProxy::IStreamFactory::Shard & shard);
};

View File

@ -11,6 +11,4 @@ SELECT _shard_num + dummy s, count() FROM remote('127.0.0.{1,2}', system.one) GR
SELECT _shard_num FROM remote('127.0.0.{1,2}', system.one) ORDER BY _shard_num;
SELECT _shard_num s FROM remote('127.0.0.{1,2}', system.one) ORDER BY _shard_num;
SELECT _shard_num s, count() FROM remote('127.0.0.{1,2}', system.one) GROUP BY s order by s;
select materialize(_shard_num), * from remote('127.{1,2}', system.one) limit 1 by dummy format Null;
SELECT _shard_num, count() FROM remote('127.0.0.{1,2}', system.one) GROUP BY _shard_num order by _shard_num;

View File

@ -0,0 +1,3 @@
select shardNum() n, shardCount() c;
select shardNum() n, shardCount() c from remote('127.0.0.{1,2,3}', system.one) order by n settings prefer_localhost_replica = 0;
select shardNum() n, shardCount() c from remote('127.0.0.{1,2,3}', system.one) order by n settings prefer_localhost_replica = 1;