// Review summary of this patch. (This copy is hard-wrapped and has lost angle-bracket content
// such as #include targets and template arguments, so some types below cannot be confirmed.)
//
// AggregatingStep.cpp:
//   Both describeActions() overloads now report "Skip merging" (the value of skip_merging)
//   unconditionally. The text overload hoists `prefix` out of the `if` so that the "Order: " line
//   and the new "Skip merging: " line share it.
//
// useDataParallelAggregation.cpp:
//   The removed isPartitionKeySuitsGroupByKey() accepted only a single GROUP BY key with a partition
//   key whose last DAG node was modulo(<INPUT named like that key>, <COLUMN>).
//   The new isPartitionKeySuitsGroupByKey():
//     1. returns false for grouping sets, or when the GROUP BY actions have array join or stateful functions;
//     2. returns false unless every column required by the partition key expression is also required
//        by the (pruned) GROUP BY actions;
//     3. collects "irreducible" GROUP BY nodes by descending through ALIAS nodes and through functions
//        accepted by isInjectiveFunction() (function_base->isInjective({}) is true, or the function is
//        "plus"/"minus" with at most one non-COLUMN argument);
//     4. runs matchTrees() and returns whether every non-INPUT partition key output is covered
//        (allOutputsCovered(): a node is covered if its match has the same type and is irreducible;
//        otherwise ALIAS defers to its child, COLUMN is always covered, FUNCTION needs all children
//        covered, INPUT and ARRAY_JOIN are not covered).
//   tryAggregatePartitionsIndependently() now looks one step further down when the typeid_cast on the
//   step below the expression succeeds (names suggest a filter step -- TODO confirm, the template
//   argument is missing in this copy), and passes a clone of the expression step's DAG to the check.
//
// Tests: the perf test settings and the stateless test are extended; six new "Skip merging: N" lines
// are added to the reference file for the new tables t7..t12.
//
// NOTE(review): the LOG_DEBUG(&Poco::Logger::get("debug"), ...) calls and the print_node lambda look
//   like temporary debugging aids; the "before"/"after" DAG dumps are emitted back to back with no
//   transformation in between. Confirm whether they should be removed before merging.
// NOTE(review): the local `irreducibe_nodes` is presumably a misspelling of `irreducible_nodes`.
// NOTE(review): in tryAggregatePartitionsIndependently the `filter` pointer is only tested for null.
diff --git a/src/Processors/QueryPlan/AggregatingStep.cpp b/src/Processors/QueryPlan/AggregatingStep.cpp index a11451396bc..db78f236a61 100644 --- a/src/Processors/QueryPlan/AggregatingStep.cpp +++ b/src/Processors/QueryPlan/AggregatingStep.cpp @@ -465,11 +465,12 @@ void AggregatingStep::transformPipeline(QueryPipelineBuilder & pipeline, const B void AggregatingStep::describeActions(FormatSettings & settings) const { params.explain(settings.out, settings.offset); + String prefix(settings.offset, settings.indent_char); if (!sort_description_for_merging.empty()) { - String prefix(settings.offset, settings.indent_char); settings.out << prefix << "Order: " << dumpSortDescription(sort_description_for_merging) << '\n'; } + settings.out << prefix << "Skip merging: " << skip_merging << '\n'; } void AggregatingStep::describeActions(JSONBuilder::JSONMap & map) const @@ -477,6 +478,7 @@ void AggregatingStep::describeActions(JSONBuilder::JSONMap & map) const params.explain(map); if (!sort_description_for_merging.empty()) map.add("Order", dumpSortDescription(sort_description_for_merging)); + map.add("Skip merging", skip_merging); } void AggregatingStep::describePipeline(FormatSettings & settings) const diff --git a/src/Processors/QueryPlan/Optimizations/useDataParallelAggregation.cpp b/src/Processors/QueryPlan/Optimizations/useDataParallelAggregation.cpp index 445e9e4806f..9d520ac5c35 100644 --- a/src/Processors/QueryPlan/Optimizations/useDataParallelAggregation.cpp +++ b/src/Processors/QueryPlan/Optimizations/useDataParallelAggregation.cpp @@ -2,38 +2,216 @@ #include #include +#include +#include #include +#include +#include + using namespace DB; namespace { -bool isPartitionKeySuitsGroupByKey(const ReadFromMergeTree & reading, const AggregatingStep & aggregating) -{ - const auto & gb_keys = aggregating.getParams().keys; - if (aggregating.isGroupingSets() || gb_keys.size() != 1) - return false; +using NodeSet = std::unordered_set; +using NodeMap = std::unordered_map; - const 
auto & pkey_nodes = reading.getStorageMetadata()->getPartitionKey().expression->getActionsDAG().getNodes(); - LOG_DEBUG(&Poco::Logger::get("debug"), "{}", reading.getStorageMetadata()->getPartitionKey().expression->getActionsDAG().dumpDAG()); - if (!pkey_nodes.empty()) +struct Frame +{ + const ActionsDAG::Node * node = nullptr; + size_t next_child = 0; +}; + +auto print_node = [](const ActionsDAG::Node * node_) +{ + String children; + for (const auto & child : node_->children) + children += fmt::format("{}, ", static_cast(child)); + LOG_DEBUG( + &Poco::Logger::get("debug"), + "current node {} {} {} {}", + static_cast(node_), + node_->result_name, + node_->type, + children); +}; + +bool isInjectiveFunction(const ActionsDAG::Node * node) +{ + if (node->function_base->isInjective({})) + return true; + + size_t fixed_args = 0; + for (const auto & child : node->children) + if (child->type == ActionsDAG::ActionType::COLUMN) + ++fixed_args; + static const std::vector injective = {"plus", "minus"}; + return (fixed_args + 1 >= node->children.size()) && (std::ranges::find(injective, node->function_base->getName()) != injective.end()); +} + +void removeInjectiveColumnsFromResultsRecursively( + const ActionsDAGPtr & actions, const ActionsDAG::Node * cur_node, NodeSet & irreducible, NodeSet & visited) +{ + if (visited.contains(cur_node)) + return; + visited.insert(cur_node); + + print_node(cur_node); + + switch (cur_node->type) { - const auto & func_node = pkey_nodes.back(); - LOG_DEBUG(&Poco::Logger::get("debug"), "{} {} {}", func_node.type, func_node.is_deterministic, func_node.children.size()); - if (func_node.type == ActionsDAG::ActionType::FUNCTION && func_node.function->getName() == "modulo" - && func_node.children.size() == 2) + case ActionsDAG::ActionType::ALIAS: + assert(cur_node->children.size() == 1); + removeInjectiveColumnsFromResultsRecursively(actions, cur_node->children.at(0), irreducible, visited); + break; + case ActionsDAG::ActionType::ARRAY_JOIN: + break; + 
case ActionsDAG::ActionType::COLUMN: + irreducible.insert(cur_node); + break; + case ActionsDAG::ActionType::FUNCTION: + LOG_DEBUG(&Poco::Logger::get("debug"), "{} {}", __LINE__, isInjectiveFunction(cur_node)); + if (!isInjectiveFunction(cur_node)) + irreducible.insert(cur_node); + else + for (const auto & child : cur_node->children) + removeInjectiveColumnsFromResultsRecursively(actions, child, irreducible, visited); + break; + case ActionsDAG::ActionType::INPUT: + irreducible.insert(cur_node); + break; + } +} + +/// Removes injective functions recursively from result columns until it is no longer possible. +NodeSet removeInjectiveColumnsFromResultsRecursively(ActionsDAGPtr actions) +{ + NodeSet irreducible; + NodeSet visited; + + for (const auto & node : actions->getOutputs()) + removeInjectiveColumnsFromResultsRecursively(actions, node, irreducible, visited); + + LOG_DEBUG(&Poco::Logger::get("debug"), "irreducible nodes:"); + for (const auto & node : irreducible) + print_node(node); + + return irreducible; +} + +bool allOutputsCovered( + const ActionsDAGPtr & partition_actions, + const NodeSet & irreducible_nodes, + const MatchedTrees::Matches & matches, + const ActionsDAG::Node * cur_node, + NodeMap & visited) +{ + if (visited.contains(cur_node)) + return visited[cur_node]; + + auto has_match_in_group_by_actions = [&irreducible_nodes, &matches, &cur_node]() + { + if (matches.contains(cur_node)) { - const auto & arg1 = func_node.children.front(); - const auto & arg2 = func_node.children.back(); - LOG_DEBUG(&Poco::Logger::get("debug"), "{} {} {}", arg1->type, arg1->result_name, arg2->type); - if (arg1->type == ActionsDAG::ActionType::INPUT && arg1->result_name == gb_keys[0] - && arg2->type == ActionsDAG::ActionType::COLUMN && typeid_cast(arg2->column.get())) - return true; + if (const auto * node_in_gb_actions = matches.at(cur_node).node; + node_in_gb_actions && node_in_gb_actions->type == cur_node->type) + { + return 
irreducible_nodes.contains(node_in_gb_actions); + } + } + return false; + }; + + bool res = has_match_in_group_by_actions(); + if (!res) + { + switch (cur_node->type) + { + case ActionsDAG::ActionType::ALIAS: + assert(cur_node->children.size() == 1); + res = allOutputsCovered(partition_actions, irreducible_nodes, matches, cur_node->children.at(0), visited); + break; + case ActionsDAG::ActionType::ARRAY_JOIN: + break; + case ActionsDAG::ActionType::COLUMN: + /// Constants don't matter, so let's always consider them matched. + res = true; + break; + case ActionsDAG::ActionType::FUNCTION: + res = true; + for (const auto & child : cur_node->children) + res &= allOutputsCovered(partition_actions, irreducible_nodes, matches, child, visited); + break; + case ActionsDAG::ActionType::INPUT: + break; } } + print_node(cur_node); + LOG_DEBUG(&Poco::Logger::get("debug"), "res={}", res); + visited[cur_node] = res; + return res; +} - return false; +bool allOutputsCovered(ActionsDAGPtr partition_actions, const NodeSet & irreducible_nodes, const MatchedTrees::Matches & matches) +{ + NodeMap visited; + + bool res = true; + for (const auto & node : partition_actions->getOutputs()) + if (node->type != ActionsDAG::ActionType::INPUT) + res &= allOutputsCovered(partition_actions, irreducible_nodes, matches, node, visited); + return res; +} + +bool isPartitionKeySuitsGroupByKey(const ReadFromMergeTree & reading, ActionsDAGPtr group_by_actions, const AggregatingStep & aggregating) +{ + /// 0. Partition key columns should be a subset of group by key columns. + /// 1. Optimization is applicable if partition by expression is a deterministic function of col1, ..., coln and group by keys are injective functions of some of col1, ..., coln. + + if (aggregating.isGroupingSets() || group_by_actions->hasArrayJoin() || group_by_actions->hasStatefulFunctions()) + return false; + + /// Check that PK columns are a subset of GBK columns. 
+ const auto partition_actions = reading.getStorageMetadata()->getPartitionKey().expression->getActionsDAG().clone(); + + /// We are interested only in calculations required to obtain group by keys. + group_by_actions->removeUnusedActions(aggregating.getParams().keys); + const auto & gb_keys = group_by_actions->getRequiredColumnsNames(); + + LOG_DEBUG(&Poco::Logger::get("debug"), "group by req cols: {}", fmt::join(gb_keys, ", ")); + LOG_DEBUG(&Poco::Logger::get("debug"), "partition by cols: {}", fmt::join(partition_actions->getRequiredColumnsNames(), ", ")); + + for (const auto & col : partition_actions->getRequiredColumnsNames()) + if (std::ranges::find(gb_keys, col) == gb_keys.end()) + return false; + + /* /// PK is always a deterministic expression without constants. No need to check. */ + + /* /// We will work only with subexpressions that depend on partition key columns. */ + LOG_DEBUG(&Poco::Logger::get("debug"), "group by actions before:\n{}", group_by_actions->dumpDAG()); + LOG_DEBUG(&Poco::Logger::get("debug"), "partition by actions before:\n{}", partition_actions->dumpDAG()); + + LOG_DEBUG(&Poco::Logger::get("debug"), "group by actions after:\n{}", group_by_actions->dumpDAG()); + LOG_DEBUG(&Poco::Logger::get("debug"), "partition by actions after:\n{}", partition_actions->dumpDAG()); + + /// For cases like `partition by col + group by col+1` or `partition by hash(col) + group by hash(col)` + const auto irreducibe_nodes = removeInjectiveColumnsFromResultsRecursively(group_by_actions); + + const auto matches = matchTrees(*group_by_actions, *partition_actions); + LOG_DEBUG(&Poco::Logger::get("debug"), "matches:"); + for (const auto & match : matches) + { + if (match.first) + print_node(match.first); + if (match.second.node) + print_node(match.second.node); + LOG_DEBUG(&Poco::Logger::get("debug"), "----------------"); + } + + const bool res = allOutputsCovered(partition_actions, irreducibe_nodes, matches); + LOG_DEBUG(&Poco::Logger::get("debug"), "result={}", 
res); + return res; } } @@ -50,15 +228,26 @@ size_t tryAggregatePartitionsIndependently(QueryPlan::Node * node, QueryPlan::No return 0; const auto * expression_node = node->children.front(); - if (expression_node->children.size() != 1 || !typeid_cast(expression_node->step.get())) + const auto * expression_step = typeid_cast(expression_node->step.get()); + if (expression_node->children.size() != 1 || !expression_step) return 0; auto * reading_step = expression_node->children.front()->step.get(); + + if (const auto * filter = typeid_cast(reading_step)) + { + const auto * filter_node = expression_node->children.front(); + if (filter_node->children.size() != 1 || !filter_node->children.front()->step) + return 0; + reading_step = filter_node->children.front()->step.get(); + } + auto * reading = typeid_cast(reading_step); if (!reading) return 0; - if (!reading->willOutputEachPartitionThroughSeparatePort() && isPartitionKeySuitsGroupByKey(*reading, *aggregating_step)) + if (!reading->willOutputEachPartitionThroughSeparatePort() + && isPartitionKeySuitsGroupByKey(*reading, expression_step->getExpression()->clone(), *aggregating_step)) { if (reading->requestOutputEachPartitionThroughSeparatePort()) aggregating_step->skipMerging(); diff --git a/tests/performance/aggregation_by_partitions.xml b/tests/performance/aggregation_by_partitions.xml index e24f589508e..7ef4c7742c1 100644 --- a/tests/performance/aggregation_by_partitions.xml +++ b/tests/performance/aggregation_by_partitions.xml @@ -1,6 +1,8 @@ - + 1 + 1 + 4096 0 4096 diff --git a/tests/queries/0_stateless/02521_aggregation_by_partitions.reference b/tests/queries/0_stateless/02521_aggregation_by_partitions.reference index 3e396bafa38..85870f54483 100644 --- a/tests/queries/0_stateless/02521_aggregation_by_partitions.reference +++ b/tests/queries/0_stateless/02521_aggregation_by_partitions.reference @@ -195,3 +195,9 @@ ExpressionTransform × 16 ExpressionTransform × 2 MergeTreeInOrder × 2 0 → 1 1000000 +Skip merging: 1 
+Skip merging: 1 +Skip merging: 0 +Skip merging: 1 +Skip merging: 0 +Skip merging: 1 diff --git a/tests/queries/0_stateless/02521_aggregation_by_partitions.sql b/tests/queries/0_stateless/02521_aggregation_by_partitions.sql index 870116d4a23..c788c6e238a 100644 --- a/tests/queries/0_stateless/02521_aggregation_by_partitions.sql +++ b/tests/queries/0_stateless/02521_aggregation_by_partitions.sql @@ -1,4 +1,7 @@ set max_threads = 16; +set allow_aggregate_partitions_independently = 1; +set force_aggregate_partitions_independently = 1; +set allow_experimental_projection_optimization = 0; create table t1(a UInt32) engine=MergeTree order by tuple() partition by a % 4; @@ -86,3 +89,64 @@ explain pipeline select a from t6 group by a settings read_in_order_two_level_me select count() from (select throwIf(count() != 2) from t6 group by a); drop table t6; + +create table t7(a UInt32) engine=MergeTree order by a partition by intDiv(a, 2); + +insert into t7 select number from numbers_mt(100); + +select replaceRegexpOne(explain, '^[ ]*(.*)', '\\1') from ( + explain actions=1 select intDiv(a, 2) as a1 from t7 group by a1 +) where explain like '%Skip merging: %'; + +drop table t7; + +create table t8(a UInt32) engine=MergeTree order by a partition by intDiv(a, 2) * 2 + 1; + +insert into t8 select number from numbers_mt(100); + +select replaceRegexpOne(explain, '^[ ]*(.*)', '\\1') from ( + explain actions=1 select intDiv(a, 2) + 1 as a1 from t8 group by a1 +) where explain like '%Skip merging: %'; + +drop table t8; + +create table t9(a UInt32) engine=MergeTree order by a partition by intDiv(a, 2); + +insert into t9 select number from numbers_mt(100); + +select replaceRegexpOne(explain, '^[ ]*(.*)', '\\1') from ( + explain actions=1 select intDiv(a, 3) as a1 from t9 group by a1 +) where explain like '%Skip merging: %'; + +drop table t9; + +create table t10(a UInt32, b UInt32) engine=MergeTree order by a partition by (intDiv(a, 2), intDiv(b, 3)); + +insert into t10 select number, 
number from numbers_mt(100); + +select replaceRegexpOne(explain, '^[ ]*(.*)', '\\1') from ( + explain actions=1 select intDiv(a, 2) + 1 as a1, intDiv(b, 3) as b1 from t10 group by a1, b1, pi() +) where explain like '%Skip merging: %'; + +drop table t10; + +-- multiplication by 2 is not injective, so optimization is not applicable +create table t11(a UInt32, b UInt32) engine=MergeTree order by a partition by (intDiv(a, 2), intDiv(b, 3)); + +insert into t11 select number, number from numbers_mt(100); + +select replaceRegexpOne(explain, '^[ ]*(.*)', '\\1') from ( + explain actions=1 select intDiv(a, 2) + 1 as a1, intDiv(b, 3) * 2 as b1 from t11 group by a1, b1, pi() +) where explain like '%Skip merging: %'; + +drop table t11; + +create table t12(a UInt32, b UInt32) engine=MergeTree order by a partition by a % 16; + +insert into t12 select number, number from numbers_mt(100); + +select replaceRegexpOne(explain, '^[ ]*(.*)', '\\1') from ( + explain actions=1 select a, b from t12 group by a, b, pi() +) where explain like '%Skip merging: %'; + +drop table t12;