From fe418eb7bed8d7a0c8b926d385b1e870f69e189c Mon Sep 17 00:00:00 2001 From: robot-clickhouse Date: Tue, 14 May 2024 10:08:19 +0000 Subject: [PATCH] Backport #62922 to 24.3: group_by_use_nulls strikes back --- src/Analyzer/FunctionNode.h | 7 +- src/Analyzer/Passes/QueryAnalysisPass.cpp | 132 ++++++++++++------ ...up_by_use_nulls_analyzer_crashes.reference | 58 ++++++++ ...23_group_by_use_nulls_analyzer_crashes.sql | 36 +++++ 4 files changed, 191 insertions(+), 42 deletions(-) diff --git a/src/Analyzer/FunctionNode.h b/src/Analyzer/FunctionNode.h index 8d14b7eeb0d..8abffcfc8ee 100644 --- a/src/Analyzer/FunctionNode.h +++ b/src/Analyzer/FunctionNode.h @@ -201,8 +201,11 @@ public: void convertToNullable() override { - chassert(kind == FunctionKind::ORDINARY); - wrap_with_nullable = true; + /// Ignore other function kinds. + /// We might try to convert aggregate/window function for invalid query + /// before the validation happened. + if (kind == FunctionKind::ORDINARY) + wrap_with_nullable = true; } void dumpTreeImpl(WriteBuffer & buffer, FormatState & format_state, size_t indent) const override; diff --git a/src/Analyzer/Passes/QueryAnalysisPass.cpp b/src/Analyzer/Passes/QueryAnalysisPass.cpp index 3967e6b067f..83d2b6d8480 100644 --- a/src/Analyzer/Passes/QueryAnalysisPass.cpp +++ b/src/Analyzer/Passes/QueryAnalysisPass.cpp @@ -474,7 +474,7 @@ struct TableExpressionData class ExpressionsStack { public: - void pushNode(const QueryTreeNodePtr & node) + void push(const QueryTreeNodePtr & node) { if (node->hasAlias()) { @@ -491,7 +491,7 @@ public: expressions.emplace_back(node); } - void popNode() + void pop() { const auto & top_expression = expressions.back(); const auto & top_expression_alias = top_expression->getAlias(); @@ -729,6 +729,8 @@ struct IdentifierResolveScope join_use_nulls = context->getSettingsRef().join_use_nulls; else if (parent_scope) join_use_nulls = parent_scope->join_use_nulls; + + alias_name_to_expression_node = 
&alias_name_to_expression_node_before_group_by; } QueryTreeNodePtr scope_node; @@ -744,7 +746,10 @@ struct IdentifierResolveScope std::unordered_map expression_argument_name_to_node; /// Alias name to query expression node - std::unordered_map alias_name_to_expression_node; + std::unordered_map alias_name_to_expression_node_before_group_by; + std::unordered_map alias_name_to_expression_node_after_group_by; + + std::unordered_map * alias_name_to_expression_node = nullptr; /// Alias name to lambda node std::unordered_map alias_name_to_lambda_node; @@ -877,6 +882,22 @@ struct IdentifierResolveScope return it->second; } + void pushExpressionNode(const QueryTreeNodePtr & node) + { + bool had_aggregate_function = expressions_in_resolve_process_stack.hasAggregateFunction(); + expressions_in_resolve_process_stack.push(node); + if (group_by_use_nulls && had_aggregate_function != expressions_in_resolve_process_stack.hasAggregateFunction()) + alias_name_to_expression_node = &alias_name_to_expression_node_before_group_by; + } + + void popExpressionNode() + { + bool had_aggregate_function = expressions_in_resolve_process_stack.hasAggregateFunction(); + expressions_in_resolve_process_stack.pop(); + if (group_by_use_nulls && had_aggregate_function != expressions_in_resolve_process_stack.hasAggregateFunction()) + alias_name_to_expression_node = &alias_name_to_expression_node_after_group_by; + } + /// Dump identifier resolve scope [[maybe_unused]] void dump(WriteBuffer & buffer) const { @@ -893,8 +914,8 @@ struct IdentifierResolveScope for (const auto & [alias_name, node] : expression_argument_name_to_node) buffer << "Alias name " << alias_name << " node " << node->formatASTForErrorMessage() << '\n'; - buffer << "Alias name to expression node table size " << alias_name_to_expression_node.size() << '\n'; - for (const auto & [alias_name, node] : alias_name_to_expression_node) + buffer << "Alias name to expression node table size " << alias_name_to_expression_node->size() << '\n'; + 
for (const auto & [alias_name, node] : *alias_name_to_expression_node) buffer << "Alias name " << alias_name << " expression node " << node->dumpTree() << '\n'; buffer << "Alias name to function node table size " << alias_name_to_lambda_node.size() << '\n'; @@ -1022,7 +1043,7 @@ private: if (is_lambda_node) { - if (scope.alias_name_to_expression_node.contains(alias)) + if (scope.alias_name_to_expression_node->contains(alias)) scope.nodes_with_duplicated_aliases.insert(node); auto [_, inserted] = scope.alias_name_to_lambda_node.insert(std::make_pair(alias, node)); @@ -1035,7 +1056,7 @@ private: if (scope.alias_name_to_lambda_node.contains(alias)) scope.nodes_with_duplicated_aliases.insert(node); - auto [_, inserted] = scope.alias_name_to_expression_node.insert(std::make_pair(alias, node)); + auto [_, inserted] = scope.alias_name_to_expression_node->insert(std::make_pair(alias, node)); if (!inserted) scope.nodes_with_duplicated_aliases.insert(node); @@ -1837,7 +1858,7 @@ void QueryAnalyzer::collectScopeValidIdentifiersForTypoCorrection( if (allow_expression_identifiers) { - for (const auto & [name, expression] : scope.alias_name_to_expression_node) + for (const auto & [name, expression] : *scope.alias_name_to_expression_node) { assert(expression); auto expression_identifier = Identifier(name); @@ -1867,7 +1888,7 @@ void QueryAnalyzer::collectScopeValidIdentifiersForTypoCorrection( { if (allow_function_identifiers) { - for (const auto & [name, _] : scope.alias_name_to_expression_node) + for (const auto & [name, _] : *scope.alias_name_to_expression_node) valid_identifiers_result.insert(Identifier(name)); } @@ -2768,7 +2789,7 @@ bool QueryAnalyzer::tryBindIdentifierToAliases(const IdentifierLookup & identifi auto get_alias_name_to_node_map = [&]() -> const std::unordered_map & { if (identifier_lookup.isExpressionLookup()) - return scope.alias_name_to_expression_node; + return *scope.alias_name_to_expression_node; else if (identifier_lookup.isFunctionLookup()) return 
scope.alias_name_to_lambda_node; @@ -2830,7 +2851,7 @@ QueryTreeNodePtr QueryAnalyzer::tryResolveIdentifierFromAliases(const Identifier auto get_alias_name_to_node_map = [&]() -> std::unordered_map & { if (identifier_lookup.isExpressionLookup()) - return scope.alias_name_to_expression_node; + return *scope.alias_name_to_expression_node; else if (identifier_lookup.isFunctionLookup()) return scope.alias_name_to_lambda_node; @@ -2868,7 +2889,7 @@ QueryTreeNodePtr QueryAnalyzer::tryResolveIdentifierFromAliases(const Identifier /// Resolve expression if necessary if (node_type == QueryTreeNodeType::IDENTIFIER) { - scope.expressions_in_resolve_process_stack.pushNode(it->second); + scope.pushExpressionNode(it->second); auto & alias_identifier_node = it->second->as(); auto identifier = alias_identifier_node.getIdentifier(); @@ -2899,9 +2920,9 @@ QueryTreeNodePtr QueryAnalyzer::tryResolveIdentifierFromAliases(const Identifier if (identifier_lookup.isExpressionLookup()) scope.alias_name_to_lambda_node.erase(identifier_bind_part); else if (identifier_lookup.isFunctionLookup()) - scope.alias_name_to_expression_node.erase(identifier_bind_part); + scope.alias_name_to_expression_node->erase(identifier_bind_part); - scope.expressions_in_resolve_process_stack.popNode(); + scope.popExpressionNode(); } else if (node_type == QueryTreeNodeType::FUNCTION) { @@ -4097,6 +4118,21 @@ IdentifierResolveResult QueryAnalyzer::tryResolveIdentifier(const IdentifierLook { bool prefer_column_name_to_alias = scope.context->getSettingsRef().prefer_column_name_to_alias; + if (identifier_lookup.isExpressionLookup()) + { + /* For aliases from ARRAY JOIN we prefer column from join tree: + * SELECT id FROM ( SELECT ... ) AS subquery ARRAY JOIN [0] AS id INNER JOIN second_table USING (id) + * In the example, identifier `id` should be resolved into one from USING (id) column. 
+ */ + auto alias_it = scope.alias_name_to_expression_node->find(identifier_lookup.identifier.getFullName()); + if (alias_it != scope.alias_name_to_expression_node->end() && alias_it->second->getNodeType() == QueryTreeNodeType::COLUMN) + { + const auto & column_node = alias_it->second->as(); + if (column_node.getColumnSource()->getNodeType() == QueryTreeNodeType::ARRAY_JOIN) + prefer_column_name_to_alias = true; + } + } + if (unlikely(prefer_column_name_to_alias)) { if (identifier_resolve_settings.allow_to_check_join_tree) @@ -5193,10 +5229,14 @@ ProjectionNames QueryAnalyzer::resolveLambda(const QueryTreeNodePtr & lambda_nod for (size_t i = 0; i < lambda_arguments_nodes_size; ++i) { auto & lambda_argument_node = lambda_arguments_nodes[i]; - auto & lambda_argument_node_typed = lambda_argument_node->as(); - const auto & lambda_argument_name = lambda_argument_node_typed.getIdentifier().getFullName(); + const auto * lambda_argument_identifier = lambda_argument_node->as(); + const auto * lambda_argument_column = lambda_argument_node->as(); + if (!lambda_argument_identifier && !lambda_argument_column) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected IDENTIFIER or COLUMN as lambda argument, got {}", lambda_node->dumpTree()); + const auto & lambda_argument_name = lambda_argument_identifier ? lambda_argument_identifier->getIdentifier().getFullName() + : lambda_argument_column->getColumnName(); - bool has_expression_node = scope.alias_name_to_expression_node.contains(lambda_argument_name); + bool has_expression_node = scope.alias_name_to_expression_node->contains(lambda_argument_name); bool has_alias_node = scope.alias_name_to_lambda_node.contains(lambda_argument_name); if (has_expression_node || has_alias_node) @@ -5204,7 +5244,7 @@ ProjectionNames QueryAnalyzer::resolveLambda(const QueryTreeNodePtr & lambda_nod throw Exception(ErrorCodes::BAD_ARGUMENTS, "Alias name '{}' inside lambda {} cannot have same name as lambda argument. 
In scope {}", lambda_argument_name, - lambda_argument_node_typed.formatASTForErrorMessage(), + lambda_argument_node->formatASTForErrorMessage(), scope.scope_node->formatASTForErrorMessage()); } @@ -6222,8 +6262,8 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id * * To resolve b we need to resolve a. */ - auto it = scope.alias_name_to_expression_node.find(node_alias); - if (it != scope.alias_name_to_expression_node.end()) + auto it = scope.alias_name_to_expression_node->find(node_alias); + if (it != scope.alias_name_to_expression_node->end()) node = it->second; if (allow_lambda_expression) @@ -6234,7 +6274,7 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id } } - scope.expressions_in_resolve_process_stack.pushNode(node); + scope.pushExpressionNode(node); auto node_type = node->getNodeType(); @@ -6263,7 +6303,7 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id resolved_identifier_node = tryResolveIdentifier({unresolved_identifier, IdentifierLookupContext::FUNCTION}, scope).resolved_identifier; if (resolved_identifier_node && !node_alias.empty()) - scope.alias_name_to_expression_node.erase(node_alias); + scope.alias_name_to_expression_node->erase(node_alias); } if (!resolved_identifier_node && allow_table_expression) @@ -6479,13 +6519,23 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id validateTreeSize(node, scope.context->getSettingsRef().max_expanded_ast_elements, node_to_tree_size); - if (!scope.expressions_in_resolve_process_stack.hasAggregateFunction()) + /// Lambda can be inside the aggregate function, so we should check parent scopes. + /// Most likely only the root scope can have an aggregate function, but let's check all just in case. 
+ bool in_aggregate_function_scope = false; + for (const auto * scope_ptr = &scope; scope_ptr; scope_ptr = scope_ptr->parent_scope) + in_aggregate_function_scope = in_aggregate_function_scope || scope_ptr->expressions_in_resolve_process_stack.hasAggregateFunction(); + + if (!in_aggregate_function_scope) { - auto it = scope.nullable_group_by_keys.find(node); - if (it != scope.nullable_group_by_keys.end()) + for (const auto * scope_ptr = &scope; scope_ptr; scope_ptr = scope_ptr->parent_scope) { - node = it->node->clone(); - node->convertToNullable(); + auto it = scope_ptr->nullable_group_by_keys.find(node); + if (it != scope_ptr->nullable_group_by_keys.end()) + { + node = it->node->clone(); + node->convertToNullable(); + break; + } } } @@ -6494,8 +6544,8 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id */ if (!node_alias.empty() && use_alias_table && !scope.group_by_use_nulls) { - auto it = scope.alias_name_to_expression_node.find(node_alias); - if (it != scope.alias_name_to_expression_node.end()) + auto it = scope.alias_name_to_expression_node->find(node_alias); + if (it != scope.alias_name_to_expression_node->end()) it->second = node; if (allow_lambda_expression) @@ -6508,7 +6558,7 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id resolved_expressions.emplace(node, result_projection_names); - scope.expressions_in_resolve_process_stack.popNode(); + scope.popExpressionNode(); bool expression_was_root = scope.expressions_in_resolve_process_stack.empty(); if (expression_was_root) scope.non_cached_identifier_lookups_during_expression_resolve.clear(); @@ -6852,11 +6902,11 @@ void QueryAnalyzer::initializeQueryJoinTreeNode(QueryTreeNodePtr & join_tree_nod */ resolve_settings.allow_to_resolve_subquery_during_identifier_resolution = false; - scope.expressions_in_resolve_process_stack.pushNode(current_join_tree_node); + scope.pushExpressionNode(current_join_tree_node); auto table_identifier_resolve_result = 
tryResolveIdentifier(table_identifier_lookup, scope, resolve_settings); - scope.expressions_in_resolve_process_stack.popNode(); + scope.popExpressionNode(); bool expression_was_root = scope.expressions_in_resolve_process_stack.empty(); if (expression_was_root) scope.non_cached_identifier_lookups_during_expression_resolve.clear(); @@ -7442,7 +7492,7 @@ void QueryAnalyzer::resolveArrayJoin(QueryTreeNodePtr & array_join_node, Identif for (auto & array_join_expression : array_join_nodes) { auto array_join_expression_alias = array_join_expression->getAlias(); - if (!array_join_expression_alias.empty() && scope.alias_name_to_expression_node.contains(array_join_expression_alias)) + if (!array_join_expression_alias.empty() && scope.alias_name_to_expression_node->contains(array_join_expression_alias)) throw Exception(ErrorCodes::MULTIPLE_EXPRESSIONS_FOR_ALIAS, "ARRAY JOIN expression {} with duplicate alias {}. In scope {}", array_join_expression->formatASTForErrorMessage(), @@ -7536,8 +7586,8 @@ void QueryAnalyzer::resolveArrayJoin(QueryTreeNodePtr & array_join_node, Identif array_join_nodes = std::move(array_join_column_expressions); for (auto & array_join_column_expression : array_join_nodes) { - auto it = scope.alias_name_to_expression_node.find(array_join_column_expression->getAlias()); - if (it != scope.alias_name_to_expression_node.end()) + auto it = scope.alias_name_to_expression_node->find(array_join_column_expression->getAlias()); + if (it != scope.alias_name_to_expression_node->end()) { auto & array_join_column_expression_typed = array_join_column_expression->as(); auto array_join_column = std::make_shared(array_join_column_expression_typed.getColumn(), @@ -8061,8 +8111,10 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier /// Clone is needed cause aliases share subtrees. 
/// If not clone, the same (shared) subtree could be resolved again with different (Nullable) type /// See 03023_group_by_use_nulls_analyzer_crashes - for (auto & [_, node] : scope.alias_name_to_expression_node) - node = node->clone(); + for (auto & [key, node] : scope.alias_name_to_expression_node_before_group_by) + scope.alias_name_to_expression_node_after_group_by[key] = node->clone(); + + scope.alias_name_to_expression_node = &scope.alias_name_to_expression_node_after_group_by; } if (query_node_typed.hasHaving()) @@ -8139,8 +8191,8 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier bool has_node_in_alias_table = false; - auto it = scope.alias_name_to_expression_node.find(node_alias); - if (it != scope.alias_name_to_expression_node.end()) + auto it = scope.alias_name_to_expression_node->find(node_alias); + if (it != scope.alias_name_to_expression_node->end()) { has_node_in_alias_table = true; @@ -8199,7 +8251,7 @@ void QueryAnalyzer::resolveQuery(const QueryTreeNodePtr & query_node, Identifier /// Remove aliases from expression and lambda nodes - for (auto & [_, node] : scope.alias_name_to_expression_node) + for (auto & [_, node] : *scope.alias_name_to_expression_node) node->removeAlias(); for (auto & [_, node] : scope.alias_name_to_lambda_node) diff --git a/tests/queries/0_stateless/03023_group_by_use_nulls_analyzer_crashes.reference b/tests/queries/0_stateless/03023_group_by_use_nulls_analyzer_crashes.reference index 17a17484a0c..02ea01eb2e6 100644 --- a/tests/queries/0_stateless/03023_group_by_use_nulls_analyzer_crashes.reference +++ b/tests/queries/0_stateless/03023_group_by_use_nulls_analyzer_crashes.reference @@ -66,3 +66,61 @@ a a a a a a +0 0 +0 \N +1 2 +1 \N +2 4 +2 \N +\N 0 +\N 2 +\N 4 +\N \N +0 0 nan +2 4 nan +1 2 nan +2 \N nan +0 \N nan +1 \N nan +\N 2 nan +\N 0 nan +\N 4 nan +\N \N nan +[] +['.'] +['.','.'] +['.','.','.'] +['.','.','.','.'] +['.','.','.','.','.'] +['.','.','.','.','.','.'] 
+['.','.','.','.','.','.','.'] +['.','.','.','.','.','.','.','.'] +['.','.','.','.','.','.','.','.','.'] +[] +[] +[] +[] +[] +[] +[] +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +10 diff --git a/tests/queries/0_stateless/03023_group_by_use_nulls_analyzer_crashes.sql b/tests/queries/0_stateless/03023_group_by_use_nulls_analyzer_crashes.sql index 68710137542..b8c173520a9 100644 --- a/tests/queries/0_stateless/03023_group_by_use_nulls_analyzer_crashes.sql +++ b/tests/queries/0_stateless/03023_group_by_use_nulls_analyzer_crashes.sql @@ -21,3 +21,39 @@ SELECT tuple(number + 1) AS x FROM numbers(10) GROUP BY number + 1, toString(x) SELECT tuple(tuple(number)) AS x FROM numbers(10) WHERE toString(toUUID(tuple(number), NULL), x) GROUP BY number, (toString(x), number) WITH CUBE SETTINGS group_by_use_nulls = 1 FORMAT Null; SELECT materialize('a'), 'a' AS key GROUP BY key WITH CUBE WITH TOTALS SETTINGS group_by_use_nulls = 1; + +EXPLAIN QUERY TREE +SELECT a, b +FROM numbers(3) +GROUP BY number as a, (number + number) as b WITH CUBE +ORDER BY a, b format Null; + +SELECT a, b +FROM numbers(3) +GROUP BY number as a, (number + number) as b WITH CUBE +ORDER BY a, b; + +SELECT + a, + b, + cramersVBiasCorrected(a, b) +FROM numbers(3) +GROUP BY + number AS a, + number + number AS b + WITH CUBE +SETTINGS group_by_use_nulls = 1; + +SELECT arrayMap(x -> '.', range(number % 10)) AS k FROM remote('127.0.0.{2,3}', numbers(10)) GROUP BY GROUPING SETS ((k)) ORDER BY k settings group_by_use_nulls=1; + +SELECT count('Lambda as function parameter') AS c FROM (SELECT ignore(ignore('Lambda as function parameter', 28, 28, 28, 28, 28, 28), 28), materialize('Lambda as function parameter'), 28, 28, 'world', 5 FROM system.numbers WHERE ignore(materialize('Lambda as function parameter'), materialize(toLowCardinality(28)), 28, 28, 28, 28, toUInt128(28)) LIMIT 2) GROUP BY GROUPING SETS ((toLowCardinality(0)), (toLowCardinality(toNullable(28))), (1)) HAVING nullIf(c, 10) < 50 ORDER BY c ASC 
NULLS FIRST settings group_by_use_nulls=1; -- { serverError ILLEGAL_AGGREGATION } + +SELECT arraySplit(x -> 0, []) WHERE materialize(1) GROUP BY (0, ignore('a')) WITH ROLLUP SETTINGS group_by_use_nulls = 1; + +SELECT arraySplit(x -> toUInt8(number), []) from numbers(1) GROUP BY toUInt8(number) WITH ROLLUP SETTINGS group_by_use_nulls = 1; + +SELECT arraySplit(number -> toUInt8(number), []) from numbers(1) GROUP BY toUInt8(number) WITH ROLLUP SETTINGS group_by_use_nulls = 1; + +SELECT count(arraySplit(number -> toUInt8(number), [arraySplit(x -> toUInt8(number), [])])) FROM numbers(10) GROUP BY number, [number] WITH ROLLUP settings group_by_use_nulls=1; -- {serverError ILLEGAL_TYPE_OF_ARGUMENT} + +SELECT count(arraySplit(x -> toUInt8(number), [])) FROM numbers(10) GROUP BY number, [number] WITH ROLLUP settings group_by_use_nulls=1; \ No newline at end of file