Backport #62922 to 24.3: group_by_use_nulls strikes back

This commit is contained in:
robot-clickhouse 2024-05-14 10:08:19 +00:00
parent db75a12c99
commit fe418eb7be
4 changed files with 191 additions and 42 deletions

View File

@ -201,8 +201,11 @@ public:
void convertToNullable() override
{
chassert(kind == FunctionKind::ORDINARY);
wrap_with_nullable = true;
/// Ignore other function kinds.
/// We might try to convert aggregate/window function for invalid query
/// before the validation happened.
if (kind == FunctionKind::ORDINARY)
wrap_with_nullable = true;
}
void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override;

View File

@ -474,7 +474,7 @@ struct TableExpressionData
class ExpressionsStack
{
public:
void pushNode(const QueryTreeNodePtr & node)
void push(const QueryTreeNodePtr & node)
{
if (node->hasAlias())
{
@ -491,7 +491,7 @@ public:
expressions.emplace_back(node);
}
void popNode()
void pop()
{
const auto & top_expression = expressions.back();
const auto & top_expression_alias = top_expression->getAlias();
@ -729,6 +729,8 @@ struct IdentifierResolveScope
join_use_nulls = context->getSettingsRef().join_use_nulls;
else if (parent_scope)
join_use_nulls = parent_scope->join_use_nulls;
alias_name_to_expression_node = &alias_name_to_expression_node_before_group_by;
}
QueryTreeNodePtr scope_node;
@ -744,7 +746,10 @@ struct IdentifierResolveScope
std::unordered_map<std::string, QueryTreeNodePtr> expression_argument_name_to_node;
/// Alias name to query expression node
std::unordered_map<std::string, QueryTreeNodePtr> alias_name_to_expression_node;
std::unordered_map<std::string, QueryTreeNodePtr> alias_name_to_expression_node_before_group_by;
std::unordered_map<std::string, QueryTreeNodePtr> alias_name_to_expression_node_after_group_by;
std::unordered_map<std::string, QueryTreeNodePtr> * alias_name_to_expression_node = nullptr;
/// Alias name to lambda node
std::unordered_map<std::string, QueryTreeNodePtr> alias_name_to_lambda_node;
@ -877,6 +882,22 @@ struct IdentifierResolveScope
return it->second;
}
void pushExpressionNode(const QueryTreeNodePtr & node)
{
bool had_aggregate_function = expressions_in_resolve_process_stack.hasAggregateFunction();
expressions_in_resolve_process_stack.push(node);
if (group_by_use_nulls && had_aggregate_function != expressions_in_resolve_process_stack.hasAggregateFunction())
alias_name_to_expression_node = &alias_name_to_expression_node_before_group_by;
}
void popExpressionNode()
{
bool had_aggregate_function = expressions_in_resolve_process_stack.hasAggregateFunction();
expressions_in_resolve_process_stack.pop();
if (group_by_use_nulls && had_aggregate_function != expressions_in_resolve_process_stack.hasAggregateFunction())
alias_name_to_expression_node = &alias_name_to_expression_node_after_group_by;
}
/// Dump identifier resolve scope
[[maybe_unused]] void dump(WriteBuffer & buffer) const
{
@ -893,8 +914,8 @@ struct IdentifierResolveScope
for (const auto & [alias_name, node] : expression_argument_name_to_node)
buffer << "Alias name " << alias_name << " node " << node->formatASTForErrorMessage() << '\n';
buffer << "Alias name to expression node table size " << alias_name_to_expression_node.size() << '\n';
for (const auto & [alias_name, node] : alias_name_to_expression_node)
buffer << "Alias name to expression node table size " << alias_name_to_expression_node->size() << '\n';
for (const auto & [alias_name, node] : *alias_name_to_expression_node)
buffer << "Alias name " << alias_name << " expression node " << node->dumpTree() << '\n';
buffer << "Alias name to function node table size " << alias_name_to_lambda_node.size() << '\n';
@ -1022,7 +1043,7 @@ private:
if (is_lambda_node)
{
if (scope.alias_name_to_expression_node.contains(alias))
if (scope.alias_name_to_expression_node->contains(alias))
scope.nodes_with_duplicated_aliases.insert(node);
auto [_, inserted] = scope.alias_name_to_lambda_node.insert(std::make_pair(alias, node));
@ -1035,7 +1056,7 @@ private:
if (scope.alias_name_to_lambda_node.contains(alias))
scope.nodes_with_duplicated_aliases.insert(node);
auto [_, inserted] = scope.alias_name_to_expression_node.insert(std::make_pair(alias, node));
auto [_, inserted] = scope.alias_name_to_expression_node->insert(std::make_pair(alias, node));
if (!inserted)
scope.nodes_with_duplicated_aliases.insert(node);
@ -1837,7 +1858,7 @@ void QueryAnalyzer::collectScopeValidIdentifiersForTypoCorrection(
if (allow_expression_identifiers)
{
for (const auto & [name, expression] : scope.alias_name_to_expression_node)
for (const auto & [name, expression] : *scope.alias_name_to_expression_node)
{
assert(expression);
auto expression_identifier = Identifier(name);
@ -1867,7 +1888,7 @@ void QueryAnalyzer::collectScopeValidIdentifiersForTypoCorrection(
{
if (allow_function_identifiers)
{
for (const auto & [name, _] : scope.alias_name_to_expression_node)
for (const auto & [name, _] : *scope.alias_name_to_expression_node)
valid_identifiers_result.insert(Identifier(name));
}
@ -2768,7 +2789,7 @@ bool QueryAnalyzer::tryBindIdentifierToAliases(const IdentifierLookup & identifi
auto get_alias_name_to_node_map = [&]() -> const std::unordered_map<std::string, QueryTreeNodePtr> &
{
if (identifier_lookup.isExpressionLookup())
return scope.alias_name_to_expression_node;
return *scope.alias_name_to_expression_node;
else if (identifier_lookup.isFunctionLookup())
return scope.alias_name_to_lambda_node;
@ -2830,7 +2851,7 @@ QueryTreeNodePtr QueryAnalyzer::tryResolveIdentifierFromAliases(const Identifier
auto get_alias_name_to_node_map = [&]() -> std::unordered_map<std::string, QueryTreeNodePtr> &
{
if (identifier_lookup.isExpressionLookup())
return scope.alias_name_to_expression_node;
return *scope.alias_name_to_expression_node;
else if (identifier_lookup.isFunctionLookup())
return scope.alias_name_to_lambda_node;
@ -2868,7 +2889,7 @@ QueryTreeNodePtr QueryAnalyzer::tryResolveIdentifierFromAliases(const Identifier
/// Resolve expression if necessary
if (node_type == QueryTreeNodeType::IDENTIFIER)
{
scope.expressions_in_resolve_process_stack.pushNode(it->second);
scope.pushExpressionNode(it->second);
auto & alias_identifier_node = it->second->as<IdentifierNode &>();
auto identifier = alias_identifier_node.getIdentifier();
@ -2899,9 +2920,9 @@ QueryTreeNodePtr QueryAnalyzer::tryResolveIdentifierFromAliases(const Identifier
if (identifier_lookup.isExpressionLookup())
scope.alias_name_to_lambda_node.erase(identifier_bind_part);
else if (identifier_lookup.isFunctionLookup())
scope.alias_name_to_expression_node.erase(identifier_bind_part);
scope.alias_name_to_expression_node->erase(identifier_bind_part);
scope.expressions_in_resolve_process_stack.popNode();
scope.popExpressionNode();
}
else if (node_type == QueryTreeNodeType::FUNCTION)
{
@ -4097,6 +4118,21 @@ IdentifierResolveResult QueryAnalyzer::tryResolveIdentifier(const IdentifierLook
{
bool prefer_column_name_to_alias = scope.context->getSettingsRef().prefer_column_name_to_alias;
if (identifier_lookup.isExpressionLookup())
{
/* For aliases from ARRAY JOIN we prefer column from join tree:
* SELECT id FROM ( SELECT ... ) AS subquery ARRAY JOIN [0] AS id INNER JOIN second_table USING (id)
* In the example, identifier `id` should be resolved into one from USING (id) column.
*/
auto alias_it = scope.alias_name_to_expression_node->find(identifier_lookup.identifier.getFullName());
if (alias_it != scope.alias_name_to_expression_node->end() && alias_it->second->getNodeType() == QueryTreeNodeType::COLUMN)
{
const auto & column_node = alias_it->second->as<ColumnNode &>();
if (column_node.getColumnSource()->getNodeType() == QueryTreeNodeType::ARRAY_JOIN)
prefer_column_name_to_alias = true;
}
}
if (unlikely(prefer_column_name_to_alias))
{
if (identifier_resolve_settings.allow_to_check_join_tree)
@ -5193,10 +5229,14 @@ ProjectionNames QueryAnalyzer::resolveLambda(const QueryTreeNodePtr & lambda_nod
for (size_t i = 0; i < lambda_arguments_nodes_size; ++i)
{
auto & lambda_argument_node = lambda_arguments_nodes[i];
auto & lambda_argument_node_typed = lambda_argument_node->as<IdentifierNode &>();
const auto & lambda_argument_name = lambda_argument_node_typed.getIdentifier().getFullName();
const auto * lambda_argument_identifier = lambda_argument_node->as<IdentifierNode>();
const auto * lambda_argument_column = lambda_argument_node->as<ColumnNode>();
if (!lambda_argument_identifier && !lambda_argument_column)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected IDENTIFIER or COLUMN as lambda argument, got {}", lambda_node->dumpTree());
const auto & lambda_argument_name = lambda_argument_identifier ? lambda_argument_identifier->getIdentifier().getFullName()
: lambda_argument_column->getColumnName();
bool has_expression_node = scope.alias_name_to_expression_node.contains(lambda_argument_name);
bool has_expression_node = scope.alias_name_to_expression_node->contains(lambda_argument_name);
bool has_alias_node = scope.alias_name_to_lambda_node.contains(lambda_argument_name);
if (has_expression_node || has_alias_node)
@ -5204,7 +5244,7 @@ ProjectionNames QueryAnalyzer::resolveLambda(const QueryTreeNodePtr & lambda_nod
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Alias name '{}' inside lambda {} cannot have same name as lambda argument. In scope {}",
lambda_argument_name,
lambda_argument_node_typed.formatASTForErrorMessage(),
lambda_argument_node->formatASTForErrorMessage(),
scope.scope_node->formatASTForErrorMessage());
}
@ -6222,8 +6262,8 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id
*
* To resolve b we need to resolve a.
*/
auto it = scope.alias_name_to_expression_node.find(node_alias);
if (it != scope.alias_name_to_expression_node.end())
auto it = scope.alias_name_to_expression_node->find(node_alias);
if (it != scope.alias_name_to_expression_node->end())
node = it->second;
if (allow_lambda_expression)
@ -6234,7 +6274,7 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id
}
}
scope.expressions_in_resolve_process_stack.pushNode(node);
scope.pushExpressionNode(node);
auto node_type = node->getNodeType();
@ -6263,7 +6303,7 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id
resolved_identifier_node = tryResolveIdentifier({unresolved_identifier, IdentifierLookupContext::FUNCTION}, scope).resolved_identifier;
if (resolved_identifier_node && !node_alias.empty())
scope.alias_name_to_expression_node.erase(node_alias);
scope.alias_name_to_expression_node->erase(node_alias);
}
if (!resolved_identifier_node && allow_table_expression)
@ -6479,13 +6519,23 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id
validateTreeSize(node, scope.context->getSettingsRef().max_expanded_ast_elements, node_to_tree_size);
if (!scope.expressions_in_resolve_process_stack.hasAggregateFunction())
/// Lambda can be inside the aggregate function, so we should check parent scopes.
/// Most likely only the root scope can have an arrgegate function, but let's check all just in case.
bool in_aggregate_function_scope = false;
for (const auto * scope_ptr = &scope; scope_ptr; scope_ptr = scope_ptr->parent_scope)
in_aggregate_function_scope = in_aggregate_function_scope || scope_ptr->expressions_in_resolve_process_stack.hasAggregateFunction();
if (!in_aggregate_function_scope)
{
auto it = scope.nullable_group_by_keys.find(node);
if (it != scope.nullable_group_by_keys.end())
for (const auto * scope_ptr = &scope; scope_ptr; scope_ptr = scope_ptr->parent_scope)
{
node = it->node->clone();
node->convertToNullable();
auto it = scope_ptr->nullable_group_by_keys.find(node);
if (it != scope_ptr->nullable_group_by_keys.end())
{
node = it->node->clone();
node->convertToNullable();
break;
}
}
}
@ -6494,8 +6544,8 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id
*/
if (!node_alias.empty() && use_alias_table && !scope.group_by_use_nulls)
{
auto it = scope.alias_name_to_expression_node.find(node_alias);
if (it != scope.alias_name_to_expression_node.end())
auto it = scope.alias_name_to_expression_node->find(node_alias);
if (it != scope.alias_name_to_expression_node->end())
it->second = node;
if (allow_lambda_expression)
@ -6508,7 +6558,7 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id
resolved_expressions.emplace(node, result_projection_names);
scope.expressions_in_resolve_process_stack.popNode();
scope.popExpressionNode();
bool expression_was_root = scope.expressions_in_resolve_process_stack.empty();
if (expression_was_root)
scope.non_cached_identifier_lookups_during_expression_resolve.clear();
@ -6852,11 +6902,11 @@ void QueryAnalyzer::initializeQueryJoinTreeNode(QueryTreeNodePtr & join_tree_nod
*/
resolve_settings.allow_to_resolve_subquery_during_identifier_resolution = false;
scope.expressions_in_resolve_process_stack.pushNode(current_join_tree_node);
scope.pushExpressionNode(current_join_tree_node);
auto table_identifier_resolve_result = tryResolveIdentifier(table_identifier_lookup, scope, resolve_settings);
scope.expressions_in_resolve_process_stack.popNode();
scope.popExpressionNode();
bool expression_was_root = scope.expressions_in_resolve_process_stack.empty();
if (expression_was_root)
scope.non_cached_identifier_lookups_during_expression_resolve.clear();
@ -7442,7 +7492,7 @@ void QueryAnalyzer::resolveArrayJoin(QueryTreeNodePtr & array_join_node, Identif
for (auto & array_join_expression : array_join_nodes)
{
auto array_join_expression_alias = array_join_expression->getAlias();
if (!array_join_expression_alias.empty() && scope.alias_name_to_expression_node.contains(array_join_expression_alias))
if (!array_join_expression_alias.empty() && scope.alias_name_to_expression_node->contains(array_join_expression_alias))
throw Exception(ErrorCodes::MULTIPLE_EXPRESSIONS_FOR_ALIAS,
"ARRAY JOIN expression {} with duplicate alias {}. In scope {}",
array_join_expression->formatASTForErrorMessage(),
@ -7536,8 +7586,8 @@ void QueryAnalyzer::resolveArrayJoin(QueryTreeNodePtr & array_join_node, Identif
array_join_nodes = std::move(array_join_column_expressions);
for (auto & array_join_column_expression : array_join_nodes)
{
auto it = scope.alias_name_to_expression_node.find(array_join_column_expression->getAlias());
if (it != scope.alias_name_to_expression_node.end())
auto it = scope.alias_name_to_expression_node->find(array_join_column_expression->getAlias());
if (it != scope.alias_name_to_expression_node->end())
{
auto & array_join_column_expression_typed = array_join_column_expression->as<ColumnNode &>();
auto array_join_column = std::make_shared<ColumnNode>(array_join_column_expression_typed.getColumn(),
@ -8061,8 +8111,10 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier
/// Clone is needed cause aliases share subtrees.
/// If not clone, the same (shared) subtree could be resolved again with different (Nullable) type
/// See 03023_group_by_use_nulls_analyzer_crashes
for (auto & [_, node] : scope.alias_name_to_expression_node)
node = node->clone();
for (auto & [key, node] : scope.alias_name_to_expression_node_before_group_by)
scope.alias_name_to_expression_node_after_group_by[key] = node->clone();
scope.alias_name_to_expression_node = &scope.alias_name_to_expression_node_after_group_by;
}
if (query_node_typed.hasHaving())
@ -8139,8 +8191,8 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier
bool has_node_in_alias_table = false;
auto it = scope.alias_name_to_expression_node.find(node_alias);
if (it != scope.alias_name_to_expression_node.end())
auto it = scope.alias_name_to_expression_node->find(node_alias);
if (it != scope.alias_name_to_expression_node->end())
{
has_node_in_alias_table = true;
@ -8199,7 +8251,7 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier
/// Remove aliases from expression and lambda nodes
for (auto & [_, node] : scope.alias_name_to_expression_node)
for (auto & [_, node] : *scope.alias_name_to_expression_node)
node->removeAlias();
for (auto & [_, node] : scope.alias_name_to_lambda_node)

View File

@ -66,3 +66,61 @@ a a
a a
a a
0 0
0 \N
1 2
1 \N
2 4
2 \N
\N 0
\N 2
\N 4
\N \N
0 0 nan
2 4 nan
1 2 nan
2 \N nan
0 \N nan
1 \N nan
\N 2 nan
\N 0 nan
\N 4 nan
\N \N nan
[]
['.']
['.','.']
['.','.','.']
['.','.','.','.']
['.','.','.','.','.']
['.','.','.','.','.','.']
['.','.','.','.','.','.','.']
['.','.','.','.','.','.','.','.']
['.','.','.','.','.','.','.','.','.']
[]
[]
[]
[]
[]
[]
[]
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
10

View File

@ -21,3 +21,39 @@ SELECT tuple(number + 1) AS x FROM numbers(10) GROUP BY number + 1, toString(x)
SELECT tuple(tuple(number)) AS x FROM numbers(10) WHERE toString(toUUID(tuple(number), NULL), x) GROUP BY number, (toString(x), number) WITH CUBE SETTINGS group_by_use_nulls = 1 FORMAT Null;
SELECT materialize('a'), 'a' AS key GROUP BY key WITH CUBE WITH TOTALS SETTINGS group_by_use_nulls = 1;
EXPLAIN QUERY TREE
SELECT a, b
FROM numbers(3)
GROUP BY number as a, (number + number) as b WITH CUBE
ORDER BY a, b format Null;
SELECT a, b
FROM numbers(3)
GROUP BY number as a, (number + number) as b WITH CUBE
ORDER BY a, b;
SELECT
a,
b,
cramersVBiasCorrected(a, b)
FROM numbers(3)
GROUP BY
number AS a,
number + number AS b
WITH CUBE
SETTINGS group_by_use_nulls = 1;
SELECT arrayMap(x -> '.', range(number % 10)) AS k FROM remote('127.0.0.{2,3}', numbers(10)) GROUP BY GROUPING SETS ((k)) ORDER BY k settings group_by_use_nulls=1;
SELECT count('Lambda as function parameter') AS c FROM (SELECT ignore(ignore('Lambda as function parameter', 28, 28, 28, 28, 28, 28), 28), materialize('Lambda as function parameter'), 28, 28, 'world', 5 FROM system.numbers WHERE ignore(materialize('Lambda as function parameter'), materialize(toLowCardinality(28)), 28, 28, 28, 28, toUInt128(28)) LIMIT 2) GROUP BY GROUPING SETS ((toLowCardinality(0)), (toLowCardinality(toNullable(28))), (1)) HAVING nullIf(c, 10) < 50 ORDER BY c ASC NULLS FIRST settings group_by_use_nulls=1; -- { serverError ILLEGAL_AGGREGATION }
SELECT arraySplit(x -> 0, []) WHERE materialize(1) GROUP BY (0, ignore('a')) WITH ROLLUP SETTINGS group_by_use_nulls = 1;
SELECT arraySplit(x -> toUInt8(number), []) from numbers(1) GROUP BY toUInt8(number) WITH ROLLUP SETTINGS group_by_use_nulls = 1;
SELECT arraySplit(number -> toUInt8(number), []) from numbers(1) GROUP BY toUInt8(number) WITH ROLLUP SETTINGS group_by_use_nulls = 1;
SELECT count(arraySplit(number -> toUInt8(number), [arraySplit(x -> toUInt8(number), [])])) FROM numbers(10) GROUP BY number, [number] WITH ROLLUP settings group_by_use_nulls=1; -- {serverError ILLEGAL_TYPE_OF_ARGUMENT}
SELECT count(arraySplit(x -> toUInt8(number), [])) FROM numbers(10) GROUP BY number, [number] WITH ROLLUP settings group_by_use_nulls=1;