mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-20 16:50:48 +00:00
produce hints for typo functions and types
This commit is contained in:
parent
f2ded6a0ae
commit
5e62a0825a
@ -128,7 +128,11 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
|
||||
return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
|
||||
}
|
||||
|
||||
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
|
||||
auto hints = this->getHints(name);
|
||||
if (!hints.empty())
|
||||
throw Exception("Unknown aggregate function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
|
||||
else
|
||||
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/NamePrompter.h>
|
||||
#include <Core/Types.h>
|
||||
#include <Poco/String.h>
|
||||
|
||||
@ -105,6 +106,12 @@ public:
|
||||
return aliases.count(name) || case_insensitive_aliases.count(name);
|
||||
}
|
||||
|
||||
/** Returns registered names that look like a misspelling of `name`.
  *
  * The candidate list is obtained from getAllRegisteredNames() on every call and
  * handed to the prompter, which selects the closest matches.
  *
  * @param name  the identifier that failed to resolve.
  * @return      the prompter's suggestions; empty when nothing is close enough.
  */
std::vector<String> getHints(const String & name) const
{
    /// Deliberately not cached in a function-local static: such a cache would be
    /// filled once by whichever object calls first, would be shared by all objects
    /// of this class, and would never see names registered afterwards.
    /// This is only reached while building an error message, so the cost is irrelevant.
    const auto registered_names = getAllRegisteredNames();
    return prompter.getHints(name, registered_names);
}
|
||||
|
||||
/// Virtual destructor; nothing to release explicitly, so the compiler-generated body is used.
virtual ~IFactoryWithAliases() = default;
|
||||
|
||||
private:
|
||||
@ -120,6 +127,12 @@ private:
|
||||
|
||||
/// Case insensitive aliases
|
||||
AliasMap case_insensitive_aliases;
|
||||
|
||||
/**
  * Prompter for names: when a function or type name is mistyped, it finds the
  * closest registered names. It returns at most MaxNumHints candidates whose
  * edit distance to the requested name does not exceed MistakeFactor.
  */
|
||||
NamePrompter</*MistakeFactor=*/2, /*MaxNumHints=*/2> prompter;
|
||||
};
|
||||
|
||||
}
|
||||
|
95
dbms/src/Common/NamePrompter.h
Normal file
95
dbms/src/Common/NamePrompter.h
Normal file
@ -0,0 +1,95 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Types.h>
|
||||
|
||||
#include <cctype>
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/** Suggests likely intended names for a misspelled identifier.
  *
  * A candidate is suggested when its case-insensitive Levenshtein distance to the
  * requested name is at most MistakeFactor. No more than MaxNumHints candidates
  * are returned, the closest ones first.
  *
  * All members are static; the class carries no state.
  */
template <size_t MistakeFactor, size_t MaxNumHints>
class NamePrompter
{
public:
    /// (edit distance to the requested name, index into the candidate list).
    using DistanceIndex = std::pair<size_t, size_t>;
    /// Max-heap: the worst match kept so far is on top and is the first to be evicted.
    using DistanceIndexQueue = std::priority_queue<DistanceIndex>;

    /** Picks the candidates closest to `name`.
      *
      * @param name               the (possibly misspelled) identifier.
      * @param prompting_strings  all known names to choose from.
      * @return at most MaxNumHints elements of `prompting_strings`, ordered by
      *         ascending distance; equal distances are ordered by position in
      *         `prompting_strings`. Empty when no candidate is within MistakeFactor.
      */
    static std::vector<String> getHints(const String & name, const std::vector<String> & prompting_strings)
    {
        DistanceIndexQueue queue;
        for (size_t i = 0; i < prompting_strings.size(); ++i)
            appendToQueue(i, name, queue, prompting_strings);
        return release(queue, prompting_strings);
    }

private:
    /// Compares two characters ignoring case.
    static bool equalIgnoringCase(char lhs, char rhs)
    {
        /// std::tolower is only defined for values representable as unsigned char,
        /// so plain (possibly negative) chars must be converted first.
        return std::tolower(static_cast<unsigned char>(lhs)) == std::tolower(static_cast<unsigned char>(rhs));
    }

    /** Case-insensitive Levenshtein (edit) distance between two strings:
      * the minimal number of single-character insertions, deletions and
      * substitutions turning one into the other.
      */
    static size_t LevenshteinDistance(const String & lhs, const String & rhs)
    {
        const size_t n = lhs.size();
        const size_t m = rhs.size();

        /// Only two rows of the dynamic programming table are needed at a time.
        /// previous[j] is the distance between the first (i - 1) characters of lhs
        /// and the first j characters of rhs.
        std::vector<size_t> previous(m + 1);
        std::vector<size_t> current(m + 1);

        for (size_t j = 0; j <= m; ++j)
            previous[j] = j;

        for (size_t i = 1; i <= n; ++i)
        {
            current[0] = i;
            for (size_t j = 1; j <= m; ++j)
            {
                if (equalIgnoringCase(lhs[i - 1], rhs[j - 1]))
                {
                    current[j] = previous[j - 1];
                }
                else
                {
                    const size_t deletion = previous[j] + 1;
                    const size_t insertion = current[j - 1] + 1;
                    const size_t substitution = previous[j - 1] + 1;
                    current[j] = std::min(deletion, std::min(insertion, substitution));
                }
            }
            previous.swap(current);
        }

        return previous[m];
    }

    /// Considers candidate number `ind` and keeps it in `queue` if it is among the
    /// MaxNumHints closest matches seen so far.
    static void appendToQueue(size_t ind, const String & name, DistanceIndexQueue & queue, const std::vector<String> & prompting_strings)
    {
        const String & candidate = prompting_strings[ind];

        /// The distance is never less than the difference of lengths,
        /// so candidates that are too long or too short are rejected without computing it.
        if (candidate.size() > name.size() + MistakeFactor || candidate.size() + MistakeFactor < name.size())
            return;

        const size_t distance = LevenshteinDistance(candidate, name);
        if (distance > MistakeFactor)
            return;

        queue.emplace(distance, ind);
        if (queue.size() > MaxNumHints)
            queue.pop();
    }

    /// Drains `queue` and returns the corresponding strings, the best match first.
    static std::vector<String> release(DistanceIndexQueue & queue, const std::vector<String> & prompting_strings)
    {
        std::vector<String> ans;
        ans.reserve(queue.size());
        while (!queue.empty())
        {
            const auto top = queue.top();
            queue.pop();
            ans.push_back(prompting_strings[top.second]);
        }
        /// The heap yields the worst match first.
        std::reverse(ans.begin(), ans.end());
        return ans;
    }
};
|
||||
|
||||
}
|
@ -7,7 +7,7 @@
|
||||
#include <Common/typeid_cast.h>
|
||||
#include <Poco/String.h>
|
||||
#include <Common/StringUtils/StringUtils.h>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -87,7 +87,11 @@ DataTypePtr DataTypeFactory::get(const String & family_name_param, const ASTPtr
|
||||
return it->second(parameters);
|
||||
}
|
||||
|
||||
throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE);
|
||||
auto hints = this->getHints(family_name);
|
||||
if (!hints.empty())
|
||||
throw Exception("Unknown data type family: " + family_name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_TYPE);
|
||||
else
|
||||
throw Exception("Unknown data type family: " + family_name, ErrorCodes::UNKNOWN_TYPE);
|
||||
}
|
||||
|
||||
|
||||
|
@ -6,6 +6,8 @@
|
||||
|
||||
#include <Poco/String.h>
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
@ -43,7 +45,13 @@ FunctionBuilderPtr FunctionFactory::get(
|
||||
{
|
||||
auto res = tryGet(name, context);
|
||||
if (!res)
|
||||
throw Exception("Unknown function " + name, ErrorCodes::UNKNOWN_FUNCTION);
|
||||
{
|
||||
auto hints = this->getHints(name);
|
||||
if (!hints.empty())
|
||||
throw Exception("Unknown function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_FUNCTION);
|
||||
else
|
||||
throw Exception("Unknown function " + name, ErrorCodes::UNKNOWN_FUNCTION);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -357,7 +357,18 @@ void ActionsVisitor::visit(const ASTPtr & ast)
|
||||
? context.getQueryContext()
|
||||
: context;
|
||||
|
||||
const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(node->name, function_context);
|
||||
FunctionBuilderPtr function_builder;
|
||||
try
|
||||
{
|
||||
function_builder = FunctionFactory::instance().get(node->name, function_context);
|
||||
}
|
||||
catch (DB::Exception & e)
|
||||
{
|
||||
auto hints = AggregateFunctionFactory::instance().getHints(node->name);
|
||||
if (!hints.empty())
|
||||
e.addMessage("Or unknown aggregate function " + node->name + ". Maybe you meant: " + toString(hints));
|
||||
e.rethrow();
|
||||
}
|
||||
|
||||
Names argument_names;
|
||||
DataTypes argument_types;
|
||||
|
Loading…
Reference in New Issue
Block a user