Analyzer added identifier typo corrections

This commit is contained in:
Maksim Kita 2022-11-07 13:51:19 +01:00
parent 80a13538ca
commit e220906c9e
5 changed files with 220 additions and 17 deletions

View File

@ -147,7 +147,7 @@ bool SettingsConstraints::checkImpl(const Settings & current_settings, SettingCh
{
if (const auto hints = current_settings.getHints(change.name); !hints.empty())
{
e.addMessage(fmt::format("Maybe you meant {}", toString(hints)));
e.addMessage(fmt::format("Maybe you meant {}", toString(hints)));
}
}
throw;

View File

@ -152,6 +152,11 @@ public:
return popFirst(1);
}
void pop_front() /// NOLINT
{
return popFirst();
}
void popLast(size_t parts_to_remove_size)
{
assert(parts_to_remove_size <= parts.size());
@ -365,6 +370,26 @@ inline std::ostream & operator<<(std::ostream & stream, const IdentifierView & i
}
template <>
struct std::hash<DB::Identifier>
{
size_t operator()(const DB::Identifier & identifier) const
{
std::hash<std::string> hash;
return hash(identifier.getFullName());
}
};
template <>
struct std::hash<DB::IdentifierView>
{
size_t operator()(const DB::IdentifierView & identifier) const
{
std::hash<std::string_view> hash;
return hash(identifier.getFullName());
}
};
/// See https://fmt.dev/latest/api.html#formatting-user-defined-types
template <>

View File

@ -1,5 +1,7 @@
#include <Analyzer/Passes/QueryAnalysisPass.h>
#include <Common/NamePrompter.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteHelpers.h>
#include <IO/Operators.h>
@ -1056,6 +1058,27 @@ private:
const ProjectionName & fill_to_expression_projection_name,
const ProjectionName & fill_step_expression_projection_name);
static void getTableExpressionValidIdentifiersForTypoCorrection(const Identifier & unresolved_identifier,
const QueryTreeNodePtr & table_expression,
const TableExpressionData & table_expression_data,
std::unordered_set<Identifier> & valid_identifiers_result);
static void getScopeValidIdentifiersForTypoCorrection(const Identifier & unresolved_identifier,
const IdentifierResolveScope & scope,
bool allow_expression_identifiers,
bool allow_function_identifiers,
bool allow_table_expression_identifiers,
std::unordered_set<Identifier> & valid_identifiers_result);
static void getScopeWithParentScopesValidIdentifiersForTypoCorrection(const Identifier & unresolved_identifier,
const IdentifierResolveScope & scope,
bool allow_expression_identifiers,
bool allow_function_identifiers,
bool allow_table_expression_identifiers,
std::unordered_set<Identifier> & valid_identifiers_result);
static std::vector<String> getIdentifierTypoHints(const Identifier & unresolved_identifier, const std::unordered_set<Identifier> & valid_identifiers);
static QueryTreeNodePtr wrapExpressionNodeInTupleElement(QueryTreeNodePtr expression_node, IdentifierView nested_path);
static QueryTreeNodePtr tryGetLambdaFromSQLUserDefinedFunctions(const std::string & function_name, ContextPtr context);
@ -1358,6 +1381,140 @@ ProjectionName QueryAnalyzer::calculateSortColumnProjectionName(const QueryTreeN
return sort_column_projection_name_buffer.str();
}
/// Get valid identifiers for typo correction from table expression
void QueryAnalyzer::getTableExpressionValidIdentifiersForTypoCorrection(const Identifier & unresolved_identifier,
const QueryTreeNodePtr & table_expression,
const TableExpressionData & table_expression_data,
std::unordered_set<Identifier> & valid_identifiers_result)
{
for (const auto & [column_name, _] : table_expression_data.column_name_to_column_node)
{
Identifier column_identifier(column_name);
if (unresolved_identifier.getPartsSize() == column_identifier.getPartsSize())
valid_identifiers_result.insert(column_identifier);
if (table_expression->hasAlias())
{
Identifier column_identifier_with_alias({table_expression->getAlias()});
for (const auto & column_identifier_part : column_identifier)
column_identifier_with_alias.push_back(column_identifier_part);
if (unresolved_identifier.getPartsSize() == column_identifier_with_alias.getPartsSize())
valid_identifiers_result.insert(column_identifier_with_alias);
}
if (!table_expression_data.table_name.empty())
{
Identifier column_identifier_with_table_name({table_expression_data.table_name});
for (const auto & column_identifier_part : column_identifier)
column_identifier_with_table_name.push_back(column_identifier_part);
if (unresolved_identifier.getPartsSize() == column_identifier_with_table_name.getPartsSize())
valid_identifiers_result.insert(column_identifier_with_table_name);
}
if (!table_expression_data.database_name.empty() && !table_expression_data.table_name.empty())
{
Identifier column_identifier_with_table_name_and_database_name({table_expression_data.database_name, table_expression_data.table_name});
for (const auto & column_identifier_part : column_identifier)
column_identifier_with_table_name_and_database_name.push_back(column_identifier_part);
if (unresolved_identifier.getPartsSize() == column_identifier_with_table_name_and_database_name.getPartsSize())
valid_identifiers_result.insert(column_identifier_with_table_name_and_database_name);
}
}
}
/// Get valid identifiers for typo correction from scope without looking at parent scopes
void QueryAnalyzer::getScopeValidIdentifiersForTypoCorrection(
const Identifier & unresolved_identifier,
const IdentifierResolveScope & scope,
bool allow_expression_identifiers,
bool allow_function_identifiers,
bool allow_table_expression_identifiers,
std::unordered_set<Identifier> & valid_identifiers_result)
{
if (allow_expression_identifiers)
{
if (unresolved_identifier.isShort())
{
for (const auto & [name, _] : scope.alias_name_to_expression_node)
valid_identifiers_result.insert(Identifier(name));
}
for (const auto & [table_expression, table_expression_data] : scope.table_expression_node_to_data)
{
getTableExpressionValidIdentifiersForTypoCorrection(unresolved_identifier,
table_expression,
table_expression_data,
valid_identifiers_result);
}
}
if (allow_function_identifiers && unresolved_identifier.isShort())
{
for (const auto & [name, _] : scope.alias_name_to_expression_node)
valid_identifiers_result.insert(Identifier(name));
}
if (allow_table_expression_identifiers && unresolved_identifier.isShort())
{
for (const auto & [name, _] : scope.alias_name_to_table_expression_node)
valid_identifiers_result.insert(Identifier(name));
}
if (unresolved_identifier.isShort())
{
for (const auto & [argument_name, expression] : scope.expression_argument_name_to_node)
{
auto expression_node_type = expression->getNodeType();
if (allow_expression_identifiers && isExpressionNodeType(expression_node_type))
valid_identifiers_result.insert(Identifier(argument_name));
else if (allow_function_identifiers && isFunctionExpressionNodeType(expression_node_type))
valid_identifiers_result.insert(Identifier(argument_name));
else if (allow_table_expression_identifiers && isTableExpressionNodeType(expression_node_type))
valid_identifiers_result.insert(Identifier(argument_name));
}
}
}
void QueryAnalyzer::getScopeWithParentScopesValidIdentifiersForTypoCorrection(
const Identifier & unresolved_identifier,
const IdentifierResolveScope & scope,
bool allow_expression_identifiers,
bool allow_function_identifiers,
bool allow_table_expression_identifiers,
std::unordered_set<Identifier> & valid_identifiers_result)
{
const IdentifierResolveScope * current_scope = &scope;
while (current_scope)
{
getScopeValidIdentifiersForTypoCorrection(unresolved_identifier,
*current_scope,
allow_expression_identifiers,
allow_function_identifiers,
allow_table_expression_identifiers,
valid_identifiers_result);
current_scope = current_scope->parent_scope;
}
}
std::vector<String> QueryAnalyzer::getIdentifierTypoHints(const Identifier & unresolved_identifier, const std::unordered_set<Identifier> & valid_identifiers)
{
std::vector<String> prompting_strings;
prompting_strings.reserve(valid_identifiers.size());
for (const auto & valid_identifier : valid_identifiers)
prompting_strings.push_back(valid_identifier.getFullName());
NamePrompter<1> prompter;
return prompter.getHints(unresolved_identifier.getFullName(), prompting_strings);
}
/** Wrap expression node in tuple element function calls for nested paths.
* Example: Expression node: compound_expression. Nested path: nested_path_1.nested_path_2.
* Result: tupleElement(tupleElement(compound_expression, 'nested_path_1'), 'nested_path_2').
@ -2107,13 +2264,24 @@ QueryTreeNodePtr QueryAnalyzer::tryResolveIdentifierFromTableExpression(const Id
}
if (!result_column || (!match_full_identifier && !compound_identifier))
throw Exception(ErrorCodes::BAD_ARGUMENTS,
"Identifier '{}' cannot be resolved from {}{}. In scope {}",
{
std::string error_message = fmt::format("Identifier '{}' cannot be resolved from {}{}. In scope {}",
identifier.getFullName(),
table_expression_data.table_expression_description,
table_expression_data.table_expression_name.empty() ? "" : " with name " + table_expression_data.table_expression_name,
scope.scope_node->formatASTForErrorMessage());
std::unordered_set<Identifier> valid_identifiers;
getTableExpressionValidIdentifiersForTypoCorrection(identifier,
table_expression_node,
table_expression_data,
valid_identifiers);
auto hints = getIdentifierTypoHints(identifier, valid_identifiers);
appendHintsMessage(error_message, hints);
throw Exception(ErrorCodes::BAD_ARGUMENTS, error_message);
}
QueryTreeNodePtr result_expression = result_column;
bool clone_is_needed = true;
@ -4309,12 +4477,24 @@ ProjectionNames QueryAnalyzer::resolveExpressionNode(QueryTreeNodePtr & node, Id
if (allow_table_expression)
message_clarification = std::string(" or ") + toStringLowercase(IdentifierLookupContext::TABLE_EXPRESSION);
throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER,
"Unknown {}{} identifier '{}' in scope {}",
std::string error_message = fmt::format("Unknown {}{} identifier '{}' in scope {}",
toStringLowercase(IdentifierLookupContext::EXPRESSION),
message_clarification,
unresolved_identifier.getFullName(),
scope.scope_node->formatASTForErrorMessage());
std::unordered_set<Identifier> valid_identifiers;
getScopeWithParentScopesValidIdentifiersForTypoCorrection(unresolved_identifier,
scope,
true,
allow_lambda_expression,
allow_table_expression,
valid_identifiers);
auto hints = getIdentifierTypoHints(unresolved_identifier, valid_identifiers);
appendHintsMessage(error_message, hints);
throw Exception(ErrorCodes::UNKNOWN_IDENTIFIER, error_message);
}
if (node->getNodeType() == QueryTreeNodeType::LIST)
@ -4876,17 +5056,15 @@ void QueryAnalyzer::initializeTableExpressionColumns(const QueryTreeNodePtr & ta
{
table_expression_data.table_name = query_node ? query_node->getCTEName() : union_node->getCTEName();
table_expression_data.table_expression_description = "subquery";
if (table_expression_node->hasAlias())
table_expression_data.table_expression_name = table_expression_node->getAlias();
}
else if (table_function_node)
{
table_expression_data.table_expression_description = "table_function";
if (table_function_node->hasAlias())
table_expression_data.table_expression_name = table_function_node->getAlias();
}
if (table_expression_node->hasAlias())
table_expression_data.table_expression_name = table_expression_node->getAlias();
if (table_node || table_function_node)
{
const auto & storage_snapshot = table_node ? table_node->getStorageSnapshot() : table_function_node->getStorageSnapshot();

View File

@ -1,9 +1,10 @@
#include <IO/WriteHelpers.h>
#include <Common/NamePrompter.h>
namespace DB::detail
namespace DB
{
void appendHintsMessageImpl(String & message, const std::vector<String> & hints)
void appendHintsMessage(String & message, const std::vector<String> & hints)
{
if (hints.empty())
{
@ -12,4 +13,5 @@ void appendHintsMessageImpl(String & message, const std::vector<String> & hints)
message += ". Maybe you meant: " + toString(hints);
}
}

View File

@ -12,6 +12,7 @@
namespace DB
{
template <size_t MaxNumHints>
class NamePrompter
{
@ -90,10 +91,7 @@ private:
}
};
namespace detail
{
void appendHintsMessageImpl(String & message, const std::vector<String> & hints);
}
void appendHintsMessage(String & message, const std::vector<String> & hints);
template <size_t MaxNumHints, typename Self>
class IHints
@ -109,7 +107,7 @@ public:
void appendHintsMessage(String & message, const String & name) const
{
auto hints = getHints(name);
detail::appendHintsMessageImpl(message, hints);
DB::appendHintsMessage(message, hints);
}
IHints() = default;