diff --git a/docs/en/sql-reference/statements/select/group-by.md b/docs/en/sql-reference/statements/select/group-by.md index ac02e9ab5a1..2df8581c447 100644 --- a/docs/en/sql-reference/statements/select/group-by.md +++ b/docs/en/sql-reference/statements/select/group-by.md @@ -243,6 +243,54 @@ If `max_rows_to_group_by` and `group_by_overflow_mode = 'any'` are not used, all You can use `WITH TOTALS` in subqueries, including subqueries in the [JOIN](../../../sql-reference/statements/select/join.md) clause (in this case, the respective total values are combined). +## GROUP BY ALL + +`GROUP BY ALL` is equivalent to listing all the SELECT-ed expressions that are not aggregate functions. + +For example: + +``` sql +SELECT + a * 2, + b, + count(c), +FROM t +GROUP BY ALL +``` + +is the same as + +``` sql +SELECT + a * 2, + b, + count(c), +FROM t +GROUP BY a * 2, b +``` + +As a special case, if a function has both aggregate functions and other fields among its arguments, the `GROUP BY` keys will contain the maximal non-aggregate expressions that can be extracted from it. 
+ +For example: + +``` sql +SELECT + substring(a, 4, 2), + substring(substring(a, 1, 2), 1, count(b)) +FROM t +GROUP BY ALL +``` + +is the same as + +``` sql +SELECT + substring(a, 4, 2), + substring(substring(a, 1, 2), 1, count(b)) +FROM t +GROUP BY substring(a, 4, 2), substring(a, 1, 2) +``` + ## Examples Example: diff --git a/docs/zh/sql-reference/statements/select/group-by.md b/docs/zh/sql-reference/statements/select/group-by.md index 90b3c7660ee..31c1649bc30 100644 --- a/docs/zh/sql-reference/statements/select/group-by.md +++ b/docs/zh/sql-reference/statements/select/group-by.md @@ -77,6 +77,54 @@ sidebar_label: GROUP BY 您可以使用 `WITH TOTALS` 在子查询中,包括在子查询 [JOIN](../../../sql-reference/statements/select/join.md) 子句(在这种情况下,将各自的总值合并)。 +## GROUP BY ALL {#group-by-all} + +`GROUP BY ALL` 相当于对所有被查询的并且不被聚合函数使用的字段进行`GROUP BY`。 + +例如 + +``` sql +SELECT + a * 2, + b, + count(c), +FROM t +GROUP BY ALL +``` + +效果等同于 + +``` sql +SELECT + a * 2, + b, + count(c), +FROM t +GROUP BY a * 2, b +``` + +对于一种特殊情况,如果一个 function 的参数中同时有聚合函数和其他字段,会对参数中能提取的最大非聚合字段进行`GROUP BY`。 + +例如: + +``` sql +SELECT + substring(a, 4, 2), + substring(substring(a, 1, 2), 1, count(b)) +FROM t +GROUP BY ALL +``` + +效果等同于 + +``` sql +SELECT + substring(a, 4, 2), + substring(substring(a, 1, 2), 1, count(b)) +FROM t +GROUP BY substring(a, 4, 2), substring(a, 1, 2) +``` + ## 例子 {#examples} 示例: diff --git a/src/Analyzer/Passes/QueryAnalysisPass.cpp b/src/Analyzer/Passes/QueryAnalysisPass.cpp index 138ff721f99..41785accc82 100644 --- a/src/Analyzer/Passes/QueryAnalysisPass.cpp +++ b/src/Analyzer/Passes/QueryAnalysisPass.cpp @@ -67,6 +67,8 @@ #include #include +#include + namespace DB { @@ -1100,6 +1102,10 @@ private: static void validateJoinTableExpressionWithoutAlias(const QueryTreeNodePtr & join_node, const QueryTreeNodePtr & table_expression_node, IdentifierResolveScope & scope); + static void expandGroupByAll(QueryNode & query_tree_node_typed); + + static std::pair 
recursivelyCollectMaxOrdinaryExpressions(QueryTreeNodePtr & node, QueryTreeNodes & into); + /// Resolve identifier functions static QueryTreeNodePtr tryResolveTableIdentifierFromDatabaseCatalog(const Identifier & table_identifier, ContextPtr context); @@ -1929,6 +1935,68 @@ void QueryAnalyzer::validateJoinTableExpressionWithoutAlias(const QueryTreeNodeP scope.scope_node->formatASTForErrorMessage()); } +std::pair QueryAnalyzer::recursivelyCollectMaxOrdinaryExpressions(QueryTreeNodePtr & node, QueryTreeNodes & into) +{ + checkStackSize(); + + if (node->as()) + { + into.push_back(node); + return {false, 1}; + } + + auto * function = node->as(); + + if (!function) + return {false, 0}; + + if (function->isAggregateFunction()) + return {true, 0}; + + UInt64 pushed_children = 0; + bool has_aggregate = false; + + for (auto & child : function->getArguments().getNodes()) + { + auto [child_has_aggregate, child_pushed_children] = recursivelyCollectMaxOrdinaryExpressions(child, into); + has_aggregate |= child_has_aggregate; + pushed_children += child_pushed_children; + } + + /// Neither the current function nor any of its arguments is an aggregate function, + /// so the function itself replaces the keys collected from its arguments + if (!has_aggregate) + { + for (UInt64 i = 0; i < pushed_children; i++) + into.pop_back(); + + into.push_back(node); + pushed_children = 1; + } + + return {has_aggregate, pushed_children}; +} + +/** Expand GROUP BY ALL by extracting all the SELECT-ed expressions that are not aggregate functions. + * + * As a special case, if a function has both aggregate functions and other fields among its arguments, + * the `GROUP BY` keys will contain the maximal non-aggregate expressions that can be extracted from it. 
+ * + * Example: + * SELECT substring(a, 4, 2), substring(substring(a, 1, 2), 1, count(b)) FROM t GROUP BY ALL + * will expand as + * SELECT substring(a, 4, 2), substring(substring(a, 1, 2), 1, count(b)) FROM t GROUP BY substring(a, 4, 2), substring(a, 1, 2) + */ +void QueryAnalyzer::expandGroupByAll(QueryNode & query_tree_node_typed) +{ + auto & group_by_nodes = query_tree_node_typed.getGroupBy().getNodes(); + auto & projection_list = query_tree_node_typed.getProjection(); + + for (auto & node : projection_list.getNodes()) + recursivelyCollectMaxOrdinaryExpressions(node, group_by_nodes); + +} + /// Resolve identifier functions implementation @@ -6006,6 +6074,9 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier node->removeAlias(); } + if (query_node_typed.isGroupByAll()) + expandGroupByAll(query_node_typed); + /** Validate aggregates * * 1. Check that there are no aggregate functions and GROUPING function in JOIN TREE, WHERE, PREWHERE, in another aggregate functions. 
diff --git a/src/Analyzer/QueryNode.cpp b/src/Analyzer/QueryNode.cpp index c5bbc193544..d31a3660336 100644 --- a/src/Analyzer/QueryNode.cpp +++ b/src/Analyzer/QueryNode.cpp @@ -54,6 +54,9 @@ void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, s if (is_group_by_with_totals) buffer << ", is_group_by_with_totals: " << is_group_by_with_totals; + if (is_group_by_all) + buffer << ", is_group_by_all: " << is_group_by_all; + std::string group_by_type; if (is_group_by_with_rollup) group_by_type = "rollup"; @@ -117,7 +120,7 @@ void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, s getWhere()->dumpTreeImpl(buffer, format_state, indent + 4); } - if (hasGroupBy()) + if (!is_group_by_all && hasGroupBy()) { buffer << '\n' << std::string(indent + 2, ' ') << "GROUP BY\n"; getGroupBy().dumpTreeImpl(buffer, format_state, indent + 4); @@ -198,7 +201,8 @@ bool QueryNode::isEqualImpl(const IQueryTreeNode & rhs) const is_group_by_with_totals == rhs_typed.is_group_by_with_totals && is_group_by_with_rollup == rhs_typed.is_group_by_with_rollup && is_group_by_with_cube == rhs_typed.is_group_by_with_cube && - is_group_by_with_grouping_sets == rhs_typed.is_group_by_with_grouping_sets; + is_group_by_with_grouping_sets == rhs_typed.is_group_by_with_grouping_sets && + is_group_by_all == rhs_typed.is_group_by_all; } void QueryNode::updateTreeHashImpl(HashState & state) const @@ -226,6 +230,7 @@ void QueryNode::updateTreeHashImpl(HashState & state) const state.update(is_group_by_with_rollup); state.update(is_group_by_with_cube); state.update(is_group_by_with_grouping_sets); + state.update(is_group_by_all); if (constant_value) { @@ -251,6 +256,7 @@ QueryTreeNodePtr QueryNode::cloneImpl() const result_query_node->is_group_by_with_rollup = is_group_by_with_rollup; result_query_node->is_group_by_with_cube = is_group_by_with_cube; result_query_node->is_group_by_with_grouping_sets = is_group_by_with_grouping_sets; + result_query_node->is_group_by_all = 
is_group_by_all; result_query_node->cte_name = cte_name; result_query_node->projection_columns = projection_columns; result_query_node->constant_value = constant_value; @@ -267,6 +273,7 @@ ASTPtr QueryNode::toASTImpl() const select_query->group_by_with_rollup = is_group_by_with_rollup; select_query->group_by_with_cube = is_group_by_with_cube; select_query->group_by_with_grouping_sets = is_group_by_with_grouping_sets; + select_query->group_by_all = is_group_by_all; if (hasWith()) select_query->setExpression(ASTSelectQuery::Expression::WITH, getWith().toAST()); @@ -283,7 +290,7 @@ ASTPtr QueryNode::toASTImpl() const if (getWhere()) select_query->setExpression(ASTSelectQuery::Expression::WHERE, getWhere()->toAST()); - if (hasGroupBy()) + if (!is_group_by_all && hasGroupBy()) select_query->setExpression(ASTSelectQuery::Expression::GROUP_BY, getGroupBy().toAST()); if (hasHaving()) diff --git a/src/Analyzer/QueryNode.h b/src/Analyzer/QueryNode.h index 1bb381c95c9..5eb70f168ec 100644 --- a/src/Analyzer/QueryNode.h +++ b/src/Analyzer/QueryNode.h @@ -176,6 +176,18 @@ public: is_group_by_with_grouping_sets = is_group_by_with_grouping_sets_value; } + /// Returns true if query node has GROUP BY ALL modifier, false otherwise + bool isGroupByAll() const + { + return is_group_by_all; + } + + /// Set query node GROUP BY ALL modifier value (stored in the is_group_by_all flag) + void setIsGroupByAll(bool is_group_by_all_value) + { + is_group_by_all = is_group_by_all_value; + } + /// Returns true if query node WITH section is not empty, false otherwise bool hasWith() const { @@ -580,6 +592,7 @@ private: bool is_group_by_with_rollup = false; bool is_group_by_with_cube = false; bool is_group_by_with_grouping_sets = false; + bool is_group_by_all = false; std::string cte_name; NamesAndTypes projection_columns; diff --git a/src/Analyzer/QueryTreeBuilder.cpp b/src/Analyzer/QueryTreeBuilder.cpp index 51745d820e7..01ecd4ece30 100644 --- a/src/Analyzer/QueryTreeBuilder.cpp +++ b/src/Analyzer/QueryTreeBuilder.cpp @@ -215,6 
+215,7 @@ QueryTreeNodePtr QueryTreeBuilder::buildSelectExpression(const ASTPtr & select_q current_query_tree->setIsGroupByWithCube(select_query_typed.group_by_with_cube); current_query_tree->setIsGroupByWithRollup(select_query_typed.group_by_with_rollup); current_query_tree->setIsGroupByWithGroupingSets(select_query_typed.group_by_with_grouping_sets); + current_query_tree->setIsGroupByAll(select_query_typed.group_by_all); current_query_tree->setOriginalAST(select_query); auto select_settings = select_query_typed.settings(); diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index da12dccd8d8..30fab527ac5 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ b/src/Interpreters/TreeRewriter.cpp @@ -1,8 +1,8 @@ #include #include + #include #include - #include #include @@ -45,10 +45,10 @@ #include #include #include -#include #include #include +#include #include @@ -784,6 +784,67 @@ void collectJoinedColumns(TableJoin & analyzed_join, ASTTableJoin & table_join, } } +std::pair recursivelyCollectMaxOrdinaryExpressions(const ASTPtr & expr, ASTExpressionList & into) +{ + checkStackSize(); + + if (expr->as()) + { + into.children.push_back(expr); + return {false, 1}; + } + + auto * function = expr->as(); + + if (!function) + return {false, 0}; + + if (AggregateUtils::isAggregateFunction(*function)) + return {true, 0}; + + UInt64 pushed_children = 0; + bool has_aggregate = false; + + for (const auto & child : function->arguments->children) + { + auto [child_has_aggregate, child_pushed_children] = recursivelyCollectMaxOrdinaryExpressions(child, into); + has_aggregate |= child_has_aggregate; + pushed_children += child_pushed_children; + } + + /// Neither the current function nor any of its arguments is an aggregate function, + /// so the function itself replaces the keys collected from its arguments + if (!has_aggregate) + { + for (UInt64 i = 0; i < pushed_children; i++) + into.children.pop_back(); + + into.children.push_back(expr); + 
pushed_children = 1; + } + + return {has_aggregate, pushed_children}; +} + +/** Expand GROUP BY ALL by extracting all the SELECT-ed expressions that are not aggregate functions. + * + * As a special case, if a function has both aggregate functions and other fields among its arguments, + * the `GROUP BY` keys will contain the maximal non-aggregate expressions that can be extracted from it. + * + * Example: + * SELECT substring(a, 4, 2), substring(substring(a, 1, 2), 1, count(b)) FROM t GROUP BY ALL + * will be expanded to + * SELECT substring(a, 4, 2), substring(substring(a, 1, 2), 1, count(b)) FROM t GROUP BY substring(a, 4, 2), substring(a, 1, 2) + */ +void expandGroupByAll(ASTSelectQuery * select_query) +{ + auto group_expression_list = std::make_shared(); + + for (const auto & expr : select_query->select()->children) + recursivelyCollectMaxOrdinaryExpressions(expr, *group_expression_list); + + select_query->setExpression(ASTSelectQuery::Expression::GROUP_BY, group_expression_list); +} std::vector getAggregates(ASTPtr & query, const ASTSelectQuery & select_query) { @@ -1276,6 +1337,10 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect( normalize(query, result.aliases, all_source_columns_set, select_options.ignore_alias, settings, /* allow_self_aliases = */ true, getContext()); + // Expand GROUP BY ALL into an explicit GROUP BY expression list + if (select_query->group_by_all) + expandGroupByAll(select_query); + /// Remove unneeded columns according to 'required_result_columns'. /// Leave all selected columns in case of DISTINCT; columns that contain arrayJoin function inside. 
/// Must be after 'normalizeTree' (after expanding aliases, for aliases not get lost) diff --git a/src/Parsers/ASTSelectQuery.cpp b/src/Parsers/ASTSelectQuery.cpp index 76849653b4e..e0e3b1a90c1 100644 --- a/src/Parsers/ASTSelectQuery.cpp +++ b/src/Parsers/ASTSelectQuery.cpp @@ -93,7 +93,7 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F where()->formatImpl(s, state, frame); } - if (groupBy()) + if (!group_by_all && groupBy()) { s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << "GROUP BY" << (s.hilite ? hilite_none : ""); if (!group_by_with_grouping_sets) @@ -104,6 +104,9 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F } } + if (group_by_all) + s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << "GROUP BY ALL" << (s.hilite ? hilite_none : ""); + if (group_by_with_rollup) s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << (s.one_line ? "" : " ") << "WITH ROLLUP" << (s.hilite ? 
hilite_none : ""); diff --git a/src/Parsers/ASTSelectQuery.h b/src/Parsers/ASTSelectQuery.h index 5e3af545f12..3db8524c8b6 100644 --- a/src/Parsers/ASTSelectQuery.h +++ b/src/Parsers/ASTSelectQuery.h @@ -82,6 +82,7 @@ public: ASTPtr clone() const override; bool distinct = false; + bool group_by_all = false; bool group_by_with_totals = false; bool group_by_with_rollup = false; bool group_by_with_cube = false; diff --git a/src/Parsers/ParserSelectQuery.cpp b/src/Parsers/ParserSelectQuery.cpp index cf335270734..201cd750af8 100644 --- a/src/Parsers/ParserSelectQuery.cpp +++ b/src/Parsers/ParserSelectQuery.cpp @@ -195,6 +195,8 @@ bool ParserSelectQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) select_query->group_by_with_cube = true; else if (s_grouping_sets.ignore(pos, expected)) select_query->group_by_with_grouping_sets = true; + else if (s_all.ignore(pos, expected)) + select_query->group_by_all = true; if ((select_query->group_by_with_rollup || select_query->group_by_with_cube || select_query->group_by_with_grouping_sets) && !open_bracket.ignore(pos, expected)) @@ -205,7 +207,7 @@ bool ParserSelectQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) if (!grouping_sets_list.parse(pos, group_expression_list, expected)) return false; } - else + else if (!select_query->group_by_all) { if (!exp_list.parse(pos, group_expression_list, expected)) return false; diff --git a/tests/queries/0_stateless/02459_group_by_all.reference b/tests/queries/0_stateless/02459_group_by_all.reference new file mode 100644 index 00000000000..7c5ccbd8fbf --- /dev/null +++ b/tests/queries/0_stateless/02459_group_by_all.reference @@ -0,0 +1,44 @@ +abc1 1 +abc2 1 +abc3 1 +abc4 1 +abc 4 +abc ab +abc ab +abc ab +abc bc +abc bc +abc a +abc a +abc a +abc a +abc a +abc a +abc a +abc a +1 abc a +1 abc a +1 abc a +1 abc a +abc1 1 +abc2 1 +abc3 1 +abc4 1 +abc 4 +abc ab +abc ab +abc ab +abc bc +abc bc +abc a +abc a +abc a +abc a +abc a +abc a +abc a +abc a +1 abc a +1 abc a +1 abc 
a +1 abc a diff --git a/tests/queries/0_stateless/02459_group_by_all.sql b/tests/queries/0_stateless/02459_group_by_all.sql new file mode 100644 index 00000000000..4f08ee331a4 --- /dev/null +++ b/tests/queries/0_stateless/02459_group_by_all.sql @@ -0,0 +1,35 @@ +DROP TABLE IF EXISTS group_by_all; + +CREATE TABLE group_by_all +( + a String, + b int, + c int +) +engine = Memory; + +insert into group_by_all values ('abc1', 1, 1), ('abc2', 1, 1), ('abc3', 1, 1), ('abc4', 1, 1); + +select a, count(b) from group_by_all group by all order by a; +select substring(a, 1, 3), count(b) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, 1, 2), 1, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, 1, 2), c, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, c, 2), c, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, c + 1, 2), 1, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, c + 1, 2), c, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(substring(a, c, count(b)), 1, count(b)), 1, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(a, 1, count(b)) from group_by_all group by all; +select count(b) AS len, substring(a, 1, 3), substring(a, 1, len) from group_by_all group by all; + +SET allow_experimental_analyzer = 1; + +select a, count(b) from group_by_all group by all order by a; +select substring(a, 1, 3), count(b) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, 1, 2), 1, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, 1, 2), c, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, c, 2), c, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, c + 1, 2), 
1, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(a, c + 1, 2), c, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(substring(substring(a, c, count(b)), 1, count(b)), 1, count(b)) from group_by_all group by all; +select substring(a, 1, 3), substring(a, 1, count(b)) from group_by_all group by all; +select count(b) AS len, substring(a, 1, 3), substring(a, 1, len) from group_by_all group by all;