Merge pull request #42265 from taofengliu/group_by_all

Support GROUP BY ALL
This commit is contained in:
Alexey Milovidov 2022-11-22 20:08:23 +01:00 committed by GitHub
commit 25780be0c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 345 additions and 7 deletions

View File

@ -243,6 +243,54 @@ If `max_rows_to_group_by` and `group_by_overflow_mode = 'any'` are not used, all
You can use `WITH TOTALS` in subqueries, including subqueries in the [JOIN](../../../sql-reference/statements/select/join.md) clause (in this case, the respective total values are combined).
## GROUP BY ALL
`GROUP BY ALL` is equivalent to listing all the SELECT-ed expressions that are not aggregate functions.
For example:
``` sql
SELECT
a * 2,
b,
count(c),
FROM t
GROUP BY ALL
```
is the same as
``` sql
SELECT
a * 2,
b,
count(c),
FROM t
GROUP BY a * 2, b
```
As a special case, if a function has both aggregate functions and other fields among its arguments, the `GROUP BY` keys will contain the largest non-aggregate sub-expressions that can be extracted from it.
For example:
``` sql
SELECT
substring(a, 4, 2),
substring(substring(a, 1, 2), 1, count(b))
FROM t
GROUP BY ALL
```
is the same as
``` sql
SELECT
substring(a, 4, 2),
substring(substring(a, 1, 2), 1, count(b))
FROM t
GROUP BY substring(a, 4, 2), substring(a, 1, 2)
```
## Examples
Example:

View File

@ -77,6 +77,54 @@ sidebar_label: GROUP BY
您可以使用 `WITH TOTALS` 在子查询中,包括在子查询 [JOIN](../../../sql-reference/statements/select/join.md) 子句(在这种情况下,将各自的总值合并)。
## GROUP BY ALL {#group-by-all}
`GROUP BY ALL` 相当于把 SELECT 中所有不是聚合函数的表达式都列为 `GROUP BY` 的键。
例如:
``` sql
SELECT
a * 2,
b,
count(c),
FROM t
GROUP BY ALL
```
效果等同于
``` sql
SELECT
a * 2,
b,
count(c),
FROM t
GROUP BY a * 2, b
```
作为一种特殊情况,如果某个函数的参数中同时包含聚合函数和其他字段,则 `GROUP BY` 的键将包含能从中提取出的最大的非聚合表达式。
例如:
``` sql
SELECT
substring(a, 4, 2),
substring(substring(a, 1, 2), 1, count(b))
FROM t
GROUP BY ALL
```
效果等同于
``` sql
SELECT
substring(a, 4, 2),
substring(substring(a, 1, 2), 1, count(b))
FROM t
GROUP BY substring(a, 4, 2), substring(a, 1, 2)
```
## 例子 {#examples}
示例:

View File

@ -67,6 +67,8 @@
#include <Analyzer/InDepthQueryTreeVisitor.h>
#include <Analyzer/QueryTreeBuilder.h>
#include <Common/checkStackSize.h>
namespace DB
{
@ -1100,6 +1102,10 @@ private:
static void validateJoinTableExpressionWithoutAlias(const QueryTreeNodePtr & join_node, const QueryTreeNodePtr & table_expression_node, IdentifierResolveScope & scope);
static void expandGroupByAll(QueryNode & query_tree_node_typed);
static std::pair<bool, UInt64> recursivelyCollectMaxOrdinaryExpressions(QueryTreeNodePtr & node, QueryTreeNodes & into);
/// Resolve identifier functions
static QueryTreeNodePtr tryResolveTableIdentifierFromDatabaseCatalog(const Identifier & table_identifier, ContextPtr context);
@ -1929,6 +1935,68 @@ void QueryAnalyzer::validateJoinTableExpressionWithoutAlias(const QueryTreeNodeP
scope.scope_node->formatASTForErrorMessage());
}
/** Append to `into` the largest sub-expressions of `node` that contain no aggregate function.
  *
  * Returns a pair {has_aggregate, pushed}:
  *  - has_aggregate: true if an aggregate function was met in the subtree of `node`;
  *  - pushed: how many nodes this call has left at the back of `into`.
  */
std::pair<bool, UInt64> QueryAnalyzer::recursivelyCollectMaxOrdinaryExpressions(QueryTreeNodePtr & node, QueryTreeNodes & into)
{
    checkStackSize();

    /// A column is a key candidate on its own.
    if (node->as<ColumnNode>())
    {
        into.push_back(node);
        return {false, 1};
    }

    auto * function_node = node->as<FunctionNode>();

    /// Nodes that are neither columns nor functions contribute nothing.
    if (function_node == nullptr)
        return {false, 0};

    /// An aggregate function is never a key; its arguments are not inspected.
    if (function_node->isAggregateFunction())
        return {true, 0};

    bool aggregate_found = false;
    UInt64 collected = 0;

    for (auto & argument : function_node->getArguments().getNodes())
    {
        const auto argument_result = recursivelyCollectMaxOrdinaryExpressions(argument, into);
        aggregate_found = aggregate_found || argument_result.first;
        collected += argument_result.second;
    }

    /// Some argument contains an aggregate: keep whatever the arguments have contributed.
    if (aggregate_found)
        return {true, collected};

    /// No aggregate anywhere below this function: discard what the arguments contributed
    /// and use the function itself as a single key instead.
    while (collected > 0)
    {
        into.pop_back();
        --collected;
    }
    into.push_back(node);

    return {false, 1};
}
/** Fill the GROUP BY list of a `GROUP BY ALL` query from its projection.
  *
  * Every projection expression is scanned, and the largest sub-expressions that do not
  * contain aggregate functions become GROUP BY keys. When a function has both aggregate
  * functions and other fields among its arguments, only the non-aggregate parts are taken.
  *
  * Example:
  * SELECT substring(a, 4, 2), substring(substring(a, 1, 2), 1, count(b)) FROM t GROUP BY ALL
  * is expanded to
  * SELECT substring(a, 4, 2), substring(substring(a, 1, 2), 1, count(b)) FROM t GROUP BY substring(a, 4, 2), substring(a, 1, 2)
  */
void QueryAnalyzer::expandGroupByAll(QueryNode & query_tree_node_typed)
{
    auto & group_by_keys = query_tree_node_typed.getGroupBy().getNodes();

    /// Keys are appended in projection order.
    for (auto & projection_node : query_tree_node_typed.getProjection().getNodes())
        recursivelyCollectMaxOrdinaryExpressions(projection_node, group_by_keys);
}
/// Resolve identifier functions implementation
@ -6006,6 +6074,9 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier
node->removeAlias();
}
if (query_node_typed.isGroupByAll())
expandGroupByAll(query_node_typed);
/** Validate aggregates
*
* 1. Check that there are no aggregate functions and GROUPING function in JOIN TREE, WHERE, PREWHERE, in another aggregate functions.

View File

@ -54,6 +54,9 @@ void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, s
if (is_group_by_with_totals)
buffer << ", is_group_by_with_totals: " << is_group_by_with_totals;
if (is_group_by_all)
buffer << ", is_group_by_all: " << is_group_by_all;
std::string group_by_type;
if (is_group_by_with_rollup)
group_by_type = "rollup";
@ -117,7 +120,7 @@ void QueryNode::dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, s
getWhere()->dumpTreeImpl(buffer, format_state, indent + 4);
}
if (hasGroupBy())
if (!is_group_by_all && hasGroupBy())
{
buffer << '\n' << std::string(indent + 2, ' ') << "GROUP BY\n";
getGroupBy().dumpTreeImpl(buffer, format_state, indent + 4);
@ -198,7 +201,8 @@ bool QueryNode::isEqualImpl(const IQueryTreeNode & rhs) const
is_group_by_with_totals == rhs_typed.is_group_by_with_totals &&
is_group_by_with_rollup == rhs_typed.is_group_by_with_rollup &&
is_group_by_with_cube == rhs_typed.is_group_by_with_cube &&
is_group_by_with_grouping_sets == rhs_typed.is_group_by_with_grouping_sets;
is_group_by_with_grouping_sets == rhs_typed.is_group_by_with_grouping_sets &&
is_group_by_all == rhs_typed.is_group_by_all;
}
void QueryNode::updateTreeHashImpl(HashState & state) const
@ -226,6 +230,7 @@ void QueryNode::updateTreeHashImpl(HashState & state) const
state.update(is_group_by_with_rollup);
state.update(is_group_by_with_cube);
state.update(is_group_by_with_grouping_sets);
state.update(is_group_by_all);
if (constant_value)
{
@ -251,6 +256,7 @@ QueryTreeNodePtr QueryNode::cloneImpl() const
result_query_node->is_group_by_with_rollup = is_group_by_with_rollup;
result_query_node->is_group_by_with_cube = is_group_by_with_cube;
result_query_node->is_group_by_with_grouping_sets = is_group_by_with_grouping_sets;
result_query_node->is_group_by_all = is_group_by_all;
result_query_node->cte_name = cte_name;
result_query_node->projection_columns = projection_columns;
result_query_node->constant_value = constant_value;
@ -267,6 +273,7 @@ ASTPtr QueryNode::toASTImpl() const
select_query->group_by_with_rollup = is_group_by_with_rollup;
select_query->group_by_with_cube = is_group_by_with_cube;
select_query->group_by_with_grouping_sets = is_group_by_with_grouping_sets;
select_query->group_by_all = is_group_by_all;
if (hasWith())
select_query->setExpression(ASTSelectQuery::Expression::WITH, getWith().toAST());
@ -283,7 +290,7 @@ ASTPtr QueryNode::toASTImpl() const
if (getWhere())
select_query->setExpression(ASTSelectQuery::Expression::WHERE, getWhere()->toAST());
if (hasGroupBy())
if (!is_group_by_all && hasGroupBy())
select_query->setExpression(ASTSelectQuery::Expression::GROUP_BY, getGroupBy().toAST());
if (hasHaving())

View File

@ -176,6 +176,18 @@ public:
is_group_by_with_grouping_sets = is_group_by_with_grouping_sets_value;
}
/// Returns true if the query node has the GROUP BY ALL modifier set, false otherwise
bool isGroupByAll() const
{
return is_group_by_all;
}
/// Set the value of the GROUP BY ALL modifier flag of the query node
void setIsGroupByAll(bool is_group_by_all_value)
{
is_group_by_all = is_group_by_all_value;
}
/// Returns true if query node WITH section is not empty, false otherwise
bool hasWith() const
{
@ -580,6 +592,7 @@ private:
bool is_group_by_with_rollup = false;
bool is_group_by_with_cube = false;
bool is_group_by_with_grouping_sets = false;
bool is_group_by_all = false;
std::string cte_name;
NamesAndTypes projection_columns;

View File

@ -215,6 +215,7 @@ QueryTreeNodePtr QueryTreeBuilder::buildSelectExpression(const ASTPtr & select_q
current_query_tree->setIsGroupByWithCube(select_query_typed.group_by_with_cube);
current_query_tree->setIsGroupByWithRollup(select_query_typed.group_by_with_rollup);
current_query_tree->setIsGroupByWithGroupingSets(select_query_typed.group_by_with_grouping_sets);
current_query_tree->setIsGroupByAll(select_query_typed.group_by_all);
current_query_tree->setOriginalAST(select_query);
auto select_settings = select_query_typed.settings();

View File

@ -1,8 +1,8 @@
#include <algorithm>
#include <memory>
#include <Core/Settings.h>
#include <Core/NamesAndTypes.h>
#include <Core/SettingsEnums.h>
#include <Interpreters/ArrayJoinedColumnsVisitor.h>
@ -45,10 +45,10 @@
#include <DataTypes/NestedUtils.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/WriteHelpers.h>
#include <Storages/IStorage.h>
#include <Common/checkStackSize.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
@ -784,6 +784,67 @@ void collectJoinedColumns(TableJoin & analyzed_join, ASTTableJoin & table_join,
}
}
std::pair<bool, UInt64> recursivelyCollectMaxOrdinaryExpressions(const ASTPtr & expr, ASTExpressionList & into)
{
checkStackSize();
if (expr->as<ASTIdentifier>())
{
into.children.push_back(expr);
return {false, 1};
}
auto * function = expr->as<ASTFunction>();
if (!function)
return {false, 0};
if (AggregateUtils::isAggregateFunction(*function))
return {true, 0};
UInt64 pushed_children = 0;
bool has_aggregate = false;
for (const auto & child : function->arguments->children)
{
auto [child_has_aggregate, child_pushed_children] = recursivelyCollectMaxOrdinaryExpressions(child, into);
has_aggregate |= child_has_aggregate;
pushed_children += child_pushed_children;
}
/// The current function is not aggregate function and there is no aggregate function in its arguments,
/// so use the current function to replace its arguments
if (!has_aggregate)
{
for (UInt64 i = 0; i < pushed_children; i++)
into.children.pop_back();
into.children.push_back(expr);
pushed_children = 1;
}
return {has_aggregate, pushed_children};
}
/** Build the GROUP BY expression list of a `GROUP BY ALL` query from its SELECT list.
  *
  * Every SELECT expression is scanned, and the largest sub-expressions that do not
  * contain aggregate functions become GROUP BY keys. When a function has both aggregate
  * functions and other fields among its arguments, only the non-aggregate parts are taken.
  *
  * Example:
  * SELECT substring(a, 4, 2), substring(substring(a, 1, 2), 1, count(b)) FROM t GROUP BY ALL
  * is expanded to
  * SELECT substring(a, 4, 2), substring(substring(a, 1, 2), 1, count(b)) FROM t GROUP BY substring(a, 4, 2), substring(a, 1, 2)
  */
void expandGroupByAll(ASTSelectQuery * select_query)
{
    auto group_by_keys = std::make_shared<ASTExpressionList>();

    /// Keys are appended in SELECT-list order.
    for (const auto & selected_expression : select_query->select()->children)
        recursivelyCollectMaxOrdinaryExpressions(selected_expression, *group_by_keys);

    /// Install the collected list as the GROUP BY expression of the query.
    select_query->setExpression(ASTSelectQuery::Expression::GROUP_BY, group_by_keys);
}
std::vector<const ASTFunction *> getAggregates(ASTPtr & query, const ASTSelectQuery & select_query)
{
@ -1276,6 +1337,10 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect(
normalize(query, result.aliases, all_source_columns_set, select_options.ignore_alias, settings, /* allow_self_aliases = */ true, getContext());
/// Expand GROUP BY ALL into an explicit list of GROUP BY keys.
if (select_query->group_by_all)
expandGroupByAll(select_query);
/// Remove unneeded columns according to 'required_result_columns'.
/// Leave all selected columns in case of DISTINCT; columns that contain arrayJoin function inside.
/// Must be after 'normalizeTree' (after expanding aliases, for aliases not get lost)

View File

@ -93,7 +93,7 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F
where()->formatImpl(s, state, frame);
}
if (groupBy())
if (!group_by_all && groupBy())
{
s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << "GROUP BY" << (s.hilite ? hilite_none : "");
if (!group_by_with_grouping_sets)
@ -104,6 +104,9 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F
}
}
if (group_by_all)
s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << "GROUP BY ALL" << (s.hilite ? hilite_none : "");
if (group_by_with_rollup)
s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << (s.one_line ? "" : " ") << "WITH ROLLUP" << (s.hilite ? hilite_none : "");

View File

@ -82,6 +82,7 @@ public:
ASTPtr clone() const override;
bool distinct = false;
bool group_by_all = false;
bool group_by_with_totals = false;
bool group_by_with_rollup = false;
bool group_by_with_cube = false;

View File

@ -195,6 +195,8 @@ bool ParserSelectQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
select_query->group_by_with_cube = true;
else if (s_grouping_sets.ignore(pos, expected))
select_query->group_by_with_grouping_sets = true;
else if (s_all.ignore(pos, expected))
select_query->group_by_all = true;
if ((select_query->group_by_with_rollup || select_query->group_by_with_cube || select_query->group_by_with_grouping_sets) &&
!open_bracket.ignore(pos, expected))
@ -205,7 +207,7 @@ bool ParserSelectQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
if (!grouping_sets_list.parse(pos, group_expression_list, expected))
return false;
}
else
else if (!select_query->group_by_all)
{
if (!exp_list.parse(pos, group_expression_list, expected))
return false;

View File

@ -0,0 +1,44 @@
abc1 1
abc2 1
abc3 1
abc4 1
abc 4
abc ab
abc ab
abc ab
abc bc
abc bc
abc a
abc a
abc a
abc a
abc a
abc a
abc a
abc a
1 abc a
1 abc a
1 abc a
1 abc a
abc1 1
abc2 1
abc3 1
abc4 1
abc 4
abc ab
abc ab
abc ab
abc bc
abc bc
abc a
abc a
abc a
abc a
abc a
abc a
abc a
abc a
1 abc a
1 abc a
1 abc a
1 abc a

View File

@ -0,0 +1,35 @@
-- Test data: four rows with distinct values of `a` and identical values of `b` and `c`.
DROP TABLE IF EXISTS group_by_all;
CREATE TABLE group_by_all
(
a String,
b int,
c int
)
engine = Memory;
insert into group_by_all values ('abc1', 1, 1), ('abc2', 1, 1), ('abc3', 1, 1), ('abc4', 1, 1);
-- A plain column next to an aggregate function.
select a, count(b) from group_by_all group by all order by a;
-- A non-aggregate function next to an aggregate function.
select substring(a, 1, 3), count(b) from group_by_all group by all;
-- Functions whose (possibly nested) arguments mix aggregate functions with other expressions.
select substring(a, 1, 3), substring(substring(a, 1, 2), 1, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(a, 1, 2), c, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(a, c, 2), c, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(a, c + 1, 2), 1, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(a, c + 1, 2), c, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(substring(a, c, count(b)), 1, count(b)), 1, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(a, 1, count(b)) from group_by_all group by all;
-- The aggregate is referenced through the alias `len` inside another expression.
select count(b) AS len, substring(a, 1, 3), substring(a, 1, len) from group_by_all group by all;
-- Repeat the same queries with the experimental analyzer enabled.
SET allow_experimental_analyzer = 1;
select a, count(b) from group_by_all group by all order by a;
select substring(a, 1, 3), count(b) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(a, 1, 2), 1, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(a, 1, 2), c, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(a, c, 2), c, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(a, c + 1, 2), 1, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(a, c + 1, 2), c, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(substring(substring(a, c, count(b)), 1, count(b)), 1, count(b)) from group_by_all group by all;
select substring(a, 1, 3), substring(a, 1, count(b)) from group_by_all group by all;
select count(b) AS len, substring(a, 1, 3), substring(a, 1, len) from group_by_all group by all;