From 7990c8cbad00f4287374573c6020208b21d3b556 Mon Sep 17 00:00:00 2001 From: Anton Popov Date: Thu, 12 Nov 2020 22:50:01 +0300 Subject: [PATCH] fix optimization with 'optimize_aggregators_of_group_by_keys' and joins --- .../AggregateFunctionOfGroupByKeysVisitor.h | 26 +++--- src/Interpreters/TreeOptimizer.cpp | 81 +++++-------------- ...egate_functions_of_key_with_join.reference | 1 + ...1_aggregate_functions_of_key_with_join.sql | 5 ++ 4 files changed, 36 insertions(+), 77 deletions(-) create mode 100644 tests/queries/0_stateless/01561_aggregate_functions_of_key_with_join.reference create mode 100644 tests/queries/0_stateless/01561_aggregate_functions_of_key_with_join.sql diff --git a/src/Interpreters/AggregateFunctionOfGroupByKeysVisitor.h b/src/Interpreters/AggregateFunctionOfGroupByKeysVisitor.h index 6b903ec45cf..587baa660cb 100644 --- a/src/Interpreters/AggregateFunctionOfGroupByKeysVisitor.h +++ b/src/Interpreters/AggregateFunctionOfGroupByKeysVisitor.h @@ -20,8 +20,8 @@ struct KeepAggregateFunctionMatcher { struct Data { - std::unordered_set & group_by_keys; - bool & keep_aggregator; + const NameSet & group_by_keys; + bool keep_aggregator; }; using Visitor = InDepthNodeVisitor; @@ -33,7 +33,7 @@ struct KeepAggregateFunctionMatcher static void visit(ASTFunction & function_node, Data & data) { - if ((function_node.arguments->children).empty()) + if (function_node.arguments->children.empty()) { data.keep_aggregator = true; return; @@ -47,12 +47,9 @@ struct KeepAggregateFunctionMatcher static void visit(ASTIdentifier & ident, Data & data) { - if (!data.group_by_keys.count(ident.shortName())) - { - /// if variable of a function is not in GROUP BY keys, this function should not be deleted + /// if variable of a function is not in GROUP BY keys, this function should not be deleted + if (!data.group_by_keys.count(ident.getColumnName())) data.keep_aggregator = true; - return; - } } static void visit(const ASTPtr & ast, Data & data) @@ -75,21 +72,21 @@ struct KeepAggregateFunctionMatcher } }; -using KeepAggregateFunctionVisitor = InDepthNodeVisitor; +using KeepAggregateFunctionVisitor = KeepAggregateFunctionMatcher::Visitor; class SelectAggregateFunctionOfGroupByKeysMatcher { public: struct Data { - std::unordered_set & group_by_keys; + const NameSet & group_by_keys; }; static bool needChildVisit(const ASTPtr & node, const ASTPtr &) { /// Don't descent into table functions and subqueries and special case for ArrayJoin. - return !node->as() && - !(node->as() || node->as() || node->as()); + return !node->as() && !node->as() + && !node->as() && !node->as(); } static void visit(ASTPtr & ast, Data & data) @@ -99,12 +96,11 @@ public: if (function_node && (function_node->name == "min" || function_node->name == "max" || function_node->name == "any" || function_node->name == "anyLast")) { - bool keep_aggregator = false; - KeepAggregateFunctionVisitor::Data keep_data{data.group_by_keys, keep_aggregator}; + KeepAggregateFunctionVisitor::Data keep_data{data.group_by_keys, false}; KeepAggregateFunctionVisitor(keep_data).visit(function_node->arguments); /// Place argument of an aggregate function instead of function - if (!keep_aggregator) + if (!keep_data.keep_aggregator) { String alias = function_node->alias; ast = (function_node->arguments->children[0])->clone(); diff --git a/src/Interpreters/TreeOptimizer.cpp b/src/Interpreters/TreeOptimizer.cpp index 61ca933dd53..b0f9ef187f1 100644 --- a/src/Interpreters/TreeOptimizer.cpp +++ b/src/Interpreters/TreeOptimizer.cpp @@ -177,43 +177,21 @@ void optimizeGroupBy(ASTSelectQuery * select_query, const NameSet & source_colum struct GroupByKeysInfo { - std::unordered_set key_names; ///set of keys' short names - bool has_identifier = false; + NameSet key_names; ///set of keys' short names bool has_function = false; - bool has_possible_collision = false; }; -GroupByKeysInfo getGroupByKeysInfo(ASTs & group_keys) +GroupByKeysInfo getGroupByKeysInfo(const ASTs & group_by_keys) { GroupByKeysInfo data; - ///filling set with short names of keys - for (auto & group_key : group_keys) + /// filling set with short names of keys + for (auto & group_key : group_by_keys) { if (group_key->as()) data.has_function = true; - if (auto * group_key_ident = group_key->as()) - { - data.has_identifier = true; - if (data.key_names.count(group_key_ident->shortName())) - { - ///There may be a collision between different tables having similar variables. - ///Due to the fact that we can't track these conflicts yet, - ///it's better to disable some optimizations to avoid elimination necessary keys. - data.has_possible_collision = true; - } - - data.key_names.insert(group_key_ident->shortName()); - } - else if (auto * group_key_func = group_key->as()) - { - data.key_names.insert(group_key_func->getColumnName()); - } - else - { - data.key_names.insert(group_key->getColumnName()); - } + data.key_names.insert(group_key->getColumnName()); } return data; @@ -225,47 +203,28 @@ void optimizeGroupByFunctionKeys(ASTSelectQuery * select_query) if (!select_query->groupBy()) return; - auto grp_by = select_query->groupBy(); - auto & group_keys = grp_by->children; + auto group_by = select_query->groupBy(); + const auto & group_by_keys = group_by->children; ASTs modified; ///result - GroupByKeysInfo group_by_keys_data = getGroupByKeysInfo(group_keys); + GroupByKeysInfo group_by_keys_data = getGroupByKeysInfo(group_by_keys); - if (!group_by_keys_data.has_function || group_by_keys_data.has_possible_collision) + if (!group_by_keys_data.has_function) return; GroupByFunctionKeysVisitor::Data visitor_data{group_by_keys_data.key_names}; - GroupByFunctionKeysVisitor(visitor_data).visit(grp_by); + GroupByFunctionKeysVisitor(visitor_data).visit(group_by); - modified.reserve(group_keys.size()); + modified.reserve(group_by_keys.size()); - ///filling the result - for (auto & group_key : group_keys) - { - if (auto * group_key_func = group_key->as()) - { - if (group_by_keys_data.key_names.count(group_key_func->getColumnName())) - modified.push_back(group_key); + /// filling the result + for (auto & group_key : group_by_keys) + if (group_by_keys_data.key_names.count(group_key->getColumnName())) + modified.push_back(group_key); - continue; - } - if (auto * group_key_ident = group_key->as()) - { - if (group_by_keys_data.key_names.count(group_key_ident->shortName())) - modified.push_back(group_key); - - continue; - } - else - { - if (group_by_keys_data.key_names.count(group_key->getColumnName())) - modified.push_back(group_key); - } - } - - ///modifying the input - grp_by->children = modified; + /// modifying the input + group_by->children = modified; } /// Eliminates min/max/any-aggregators of functions of GROUP BY keys @@ -274,10 +233,8 @@ void optimizeAggregateFunctionsOfGroupByKeys(ASTSelectQuery * select_query, ASTP if (!select_query->groupBy()) return; - auto grp_by = select_query->groupBy(); - auto & group_keys = grp_by->children; - - GroupByKeysInfo group_by_keys_data = getGroupByKeysInfo(group_keys); + auto & group_by_keys = select_query->groupBy()->children; + GroupByKeysInfo group_by_keys_data = getGroupByKeysInfo(group_by_keys); SelectAggregateFunctionOfGroupByKeysVisitor::Data visitor_data{group_by_keys_data.key_names}; SelectAggregateFunctionOfGroupByKeysVisitor(visitor_data).visit(node); diff --git a/tests/queries/0_stateless/01561_aggregate_functions_of_key_with_join.reference b/tests/queries/0_stateless/01561_aggregate_functions_of_key_with_join.reference new file mode 100644 index 00000000000..9874d6464ab --- /dev/null +++ b/tests/queries/0_stateless/01561_aggregate_functions_of_key_with_join.reference @@ -0,0 +1 @@ +1 2 diff --git a/tests/queries/0_stateless/01561_aggregate_functions_of_key_with_join.sql b/tests/queries/0_stateless/01561_aggregate_functions_of_key_with_join.sql new file mode 100644 index 00000000000..66047fcc1a6 --- /dev/null +++ b/tests/queries/0_stateless/01561_aggregate_functions_of_key_with_join.sql @@ -0,0 +1,5 @@ +SET optimize_aggregators_of_group_by_keys = 1; +SELECT source.key, max(target.key) FROM (SELECT 1 key, 'x' name) source +INNER JOIN (SELECT 2 key, 'x' name) target +ON source.name = target.name +GROUP BY source.key;