From e220906c9e8ff9345d17ea8b282dfb4d67ee60a7 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Mon, 7 Nov 2022 13:51:19 +0100 Subject: [PATCH] Analyzer added identifier typo corrections --- src/Access/SettingsConstraints.cpp | 2 +- src/Analyzer/Identifier.h | 25 +++ src/Analyzer/Passes/QueryAnalysisPass.cpp | 196 +++++++++++++++++++++- src/Common/NamePrompter.cpp | 6 +- src/Common/NamePrompter.h | 8 +- 5 files changed, 220 insertions(+), 17 deletions(-) diff --git a/src/Access/SettingsConstraints.cpp b/src/Access/SettingsConstraints.cpp index d97a78c78ab..0317e43f8d1 100644 --- a/src/Access/SettingsConstraints.cpp +++ b/src/Access/SettingsConstraints.cpp @@ -147,7 +147,7 @@ bool SettingsConstraints::checkImpl(const Settings & current_settings, SettingCh { if (const auto hints = current_settings.getHints(change.name); !hints.empty()) { - e.addMessage(fmt::format("Maybe you meant {}", toString(hints))); + e.addMessage(fmt::format("Maybe you meant {}", toString(hints))); } } throw; diff --git a/src/Analyzer/Identifier.h b/src/Analyzer/Identifier.h index 2252ce2854f..abfee7cafc2 100644 --- a/src/Analyzer/Identifier.h +++ b/src/Analyzer/Identifier.h @@ -152,6 +152,11 @@ public: return popFirst(1); } + void pop_front() /// NOLINT + { + return popFirst(); + } + void popLast(size_t parts_to_remove_size) { assert(parts_to_remove_size <= parts.size()); @@ -365,6 +370,26 @@ inline std::ostream & operator<<(std::ostream & stream, const IdentifierView & i } +template <> +struct std::hash +{ + size_t operator()(const DB::Identifier & identifier) const + { + std::hash hash; + return hash(identifier.getFullName()); + } +}; + +template <> +struct std::hash +{ + size_t operator()(const DB::IdentifierView & identifier) const + { + std::hash hash; + return hash(identifier.getFullName()); + } +}; + /// See https://fmt.dev/latest/api.html#formatting-user-defined-types template <> diff --git a/src/Analyzer/Passes/QueryAnalysisPass.cpp b/src/Analyzer/Passes/QueryAnalysisPass.cpp index 5dbc8ffdb3c..76a51c05243 100644 --- a/src/Analyzer/Passes/QueryAnalysisPass.cpp +++ b/src/Analyzer/Passes/QueryAnalysisPass.cpp @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -1056,6 +1058,27 @@ private: const ProjectionName & fill_to_expression_projection_name, const ProjectionName & fill_step_expression_projection_name); + static void getTableExpressionValidIdentifiersForTypoCorrection(const Identifier & unresolved_identifier, + const QueryTreeNodePtr & table_expression, + const TableExpressionData & table_expression_data, + std::unordered_set & valid_identifiers_result); + + static void getScopeValidIdentifiersForTypoCorrection(const Identifier & unresolved_identifier, + const IdentifierResolveScope & scope, + bool allow_expression_identifiers, + bool allow_function_identifiers, + bool allow_table_expression_identifiers, + std::unordered_set & valid_identifiers_result); + + static void getScopeWithParentScopesValidIdentifiersForTypoCorrection(const Identifier & unresolved_identifier, + const IdentifierResolveScope & scope, + bool allow_expression_identifiers, + bool allow_function_identifiers, + bool allow_table_expression_identifiers, + std::unordered_set & valid_identifiers_result); + + static std::vector getIdentifierTypoHints(const Identifier & unresolved_identifier, const std::unordered_set & valid_identifiers); + static QueryTreeNodePtr wrapExpressionNodeInTupleElement(QueryTreeNodePtr expression_node, IdentifierView nested_path); static QueryTreeNodePtr tryGetLambdaFromSQLUserDefinedFunctions(const std::string & function_name, ContextPtr context); @@ -1358,6 +1381,140 @@ ProjectionName QueryAnalyzer::calculateSortColumnProjectionName(const QueryTreeN return sort_column_projection_name_buffer.str(); } +/// Get valid identifiers for typo correction from table expression +void QueryAnalyzer::getTableExpressionValidIdentifiersForTypoCorrection(const Identifier & unresolved_identifier, + const QueryTreeNodePtr & table_expression, + const TableExpressionData & table_expression_data, + std::unordered_set & valid_identifiers_result) +{ + for (const auto & [column_name, _] : table_expression_data.column_name_to_column_node) + { + Identifier column_identifier(column_name); + if (unresolved_identifier.getPartsSize() == column_identifier.getPartsSize()) + valid_identifiers_result.insert(column_identifier); + + if (table_expression->hasAlias()) + { + Identifier column_identifier_with_alias({table_expression->getAlias()}); + for (const auto & column_identifier_part : column_identifier) + column_identifier_with_alias.push_back(column_identifier_part); + + if (unresolved_identifier.getPartsSize() == column_identifier_with_alias.getPartsSize()) + valid_identifiers_result.insert(column_identifier_with_alias); + } + + if (!table_expression_data.table_name.empty()) + { + Identifier column_identifier_with_table_name({table_expression_data.table_name}); + for (const auto & column_identifier_part : column_identifier) + column_identifier_with_table_name.push_back(column_identifier_part); + + if (unresolved_identifier.getPartsSize() == column_identifier_with_table_name.getPartsSize()) + valid_identifiers_result.insert(column_identifier_with_table_name); + } + + if (!table_expression_data.database_name.empty() && !table_expression_data.table_name.empty()) + { + Identifier column_identifier_with_table_name_and_database_name({table_expression_data.database_name, table_expression_data.table_name}); + for (const auto & column_identifier_part : column_identifier) + column_identifier_with_table_name_and_database_name.push_back(column_identifier_part); + + if (unresolved_identifier.getPartsSize() == column_identifier_with_table_name_and_database_name.getPartsSize()) + valid_identifiers_result.insert(column_identifier_with_table_name_and_database_name); + } + } +} + +/// Get valid identifiers for typo correction from scope without looking at parent scopes +void QueryAnalyzer::getScopeValidIdentifiersForTypoCorrection( + const Identifier & unresolved_identifier, + const IdentifierResolveScope & scope, + bool allow_expression_identifiers, + bool allow_function_identifiers, + bool allow_table_expression_identifiers, + std::unordered_set & valid_identifiers_result) +{ + if (allow_expression_identifiers) + { + if (unresolved_identifier.isShort()) + { + for (const auto & [name, _] : scope.alias_name_to_expression_node) + valid_identifiers_result.insert(Identifier(name)); + } + + for (const auto & [table_expression, table_expression_data] : scope.table_expression_node_to_data) + { + getTableExpressionValidIdentifiersForTypoCorrection(unresolved_identifier, + table_expression, + table_expression_data, + valid_identifiers_result); + } + + } + + if (allow_function_identifiers && unresolved_identifier.isShort()) + { + for (const auto & [name, _] : scope.alias_name_to_expression_node) + valid_identifiers_result.insert(Identifier(name)); + } + + if (allow_table_expression_identifiers && unresolved_identifier.isShort()) + { + for (const auto & [name, _] : scope.alias_name_to_table_expression_node) + valid_identifiers_result.insert(Identifier(name)); + } + + if (unresolved_identifier.isShort()) + { + for (const auto & [argument_name, expression] : scope.expression_argument_name_to_node) + { + auto expression_node_type = expression->getNodeType(); + + if (allow_expression_identifiers && isExpressionNodeType(expression_node_type)) + valid_identifiers_result.insert(Identifier(argument_name)); + else if (allow_function_identifiers && isFunctionExpressionNodeType(expression_node_type)) + valid_identifiers_result.insert(Identifier(argument_name)); + else if (allow_table_expression_identifiers && isTableExpressionNodeType(expression_node_type)) + valid_identifiers_result.insert(Identifier(argument_name)); + } + } +} + +void QueryAnalyzer::getScopeWithParentScopesValidIdentifiersForTypoCorrection( + const Identifier & unresolved_identifier, + const IdentifierResolveScope & scope, + bool allow_expression_identifiers, + bool allow_function_identifiers, + bool allow_table_expression_identifiers, + std::unordered_set & valid_identifiers_result) +{ + const IdentifierResolveScope * current_scope = &scope; + + while (current_scope) + { + getScopeValidIdentifiersForTypoCorrection(unresolved_identifier, + *current_scope, + allow_expression_identifiers, + allow_function_identifiers, + allow_table_expression_identifiers, + valid_identifiers_result); + + current_scope = current_scope->parent_scope; + } +} + +std::vector QueryAnalyzer::getIdentifierTypoHints(const Identifier & unresolved_identifier, const std::unordered_set & valid_identifiers) +{ + std::vector prompting_strings; + prompting_strings.reserve(valid_identifiers.size()); + + for (const auto & valid_identifier : valid_identifiers) + prompting_strings.push_back(valid_identifier.getFullName()); + + NamePrompter<1> prompter; + return prompter.getHints(unresolved_identifier.getFullName(), prompting_strings); +} + /** Wrap expression node in tuple element function calls for nested paths. * Example: Expression node: compound_expression. Nested path: nested_path_1.nested_path_2. * Result: tupleElement(tupleElement(compound_expression, 'nested_path_1'), 'nested_path_2'). @@ -2107,13 +2264,24 @@ QueryTreeNodePtr QueryAnalyzer::tryResolveIdentifierFromTableExpression(const Id } if (!result_column || (!match_full_identifier && !compound_identifier)) - throw Exception(ErrorCodes::BAD_ARGUMENTS, - "Identifier '{}' cannot be resolved from {}{}. In scope {}", + { + std::string error_message = fmt::format("Identifier '{}' cannot be resolved from {}{}. In scope {}", identifier.getFullName(), table_expression_data.table_expression_description, table_expression_data.table_expression_name.empty() ? "" : " with name " + table_expression_data.table_expression_name, scope.scope_node->formatASTForErrorMessage()); + std::unordered_set valid_identifiers; + getTableExpressionValidIdentifiersForTypoCorrection(identifier, + table_expression_node, + table_expression_data, + valid_identifiers); + auto hints = getIdentifierTypoHints(identifier, valid_identifiers); + appendHintsMessage(error_message, hints); + + throw Exception(ErrorCodes::BAD_ARGUMENTS, error_message); + } + QueryTreeNodePtr result_expression = result_column; bool clone_is_needed = true; @@ -4309,12 +4477,24 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id if (allow_table_expression) message_clarification = std::string(" or ") + toStringLowercase(IdentifierLookupContext::TABLE_EXPRESSION); - throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER, - "Unknown {}{} identifier '{}' in scope {}", + std::string error_message = fmt::format("Unknown {}{} identifier '{}' in scope {}", toStringLowercase(IdentifierLookupContext::EXPRESSION), message_clarification, unresolved_identifier.getFullName(), scope.scope_node->formatASTForErrorMessage()); + + std::unordered_set valid_identifiers; + getScopeWithParentScopesValidIdentifiersForTypoCorrection(unresolved_identifier, + scope, + true, + allow_lambda_expression, + allow_table_expression, + valid_identifiers); + + auto hints = getIdentifierTypoHints(unresolved_identifier, valid_identifiers); + appendHintsMessage(error_message, hints); + + throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER, error_message); } if (node->getNodeType() == QueryTreeNodeType::LIST) @@ -4876,17 +5056,15 @@ void QueryAnalyzer::initializeTableExpressionColumns(const QueryTreeNodePtr & ta { table_expression_data.table_name = query_node ? query_node->getCTEName() : union_node->getCTEName(); table_expression_data.table_expression_description = "subquery"; - - if (table_expression_node->hasAlias()) - table_expression_data.table_expression_name = table_expression_node->getAlias(); } else if (table_function_node) { table_expression_data.table_expression_description = "table_function"; - if (table_function_node->hasAlias()) - table_expression_data.table_expression_name = table_function_node->getAlias(); } + if (table_expression_node->hasAlias()) + table_expression_data.table_expression_name = table_expression_node->getAlias(); + if (table_node || table_function_node) { const auto & storage_snapshot = table_node ? table_node->getStorageSnapshot() : table_function_node->getStorageSnapshot(); diff --git a/src/Common/NamePrompter.cpp b/src/Common/NamePrompter.cpp index c5a2224dcb4..ea42b801ee2 100644 --- a/src/Common/NamePrompter.cpp +++ b/src/Common/NamePrompter.cpp @@ -1,9 +1,10 @@ #include #include -namespace DB::detail +namespace DB { -void appendHintsMessageImpl(String & message, const std::vector & hints) + +void appendHintsMessage(String & message, const std::vector & hints) { if (hints.empty()) { @@ -12,4 +13,5 @@ void appendHintsMessageImpl(String & message, const std::vector & hints) message += ". Maybe you meant: " + toString(hints); } + } diff --git a/src/Common/NamePrompter.h b/src/Common/NamePrompter.h index 962a89a8e76..8b69dd100ee 100644 --- a/src/Common/NamePrompter.h +++ b/src/Common/NamePrompter.h @@ -12,6 +12,7 @@ namespace DB { + template class NamePrompter { @@ -90,10 +91,7 @@ private: } }; -namespace detail -{ -void appendHintsMessageImpl(String & message, const std::vector & hints); -} +void appendHintsMessage(String & message, const std::vector & hints); template class IHints @@ -109,7 +107,7 @@ public: void appendHintsMessage(String & message, const String & name) const { auto hints = getHints(name); - detail::appendHintsMessageImpl(message, hints); + DB::appendHintsMessage(message, hints); } IHints() = default;