Normalize function names

This commit is contained in:
Amos Bird 2021-02-14 19:09:36 +08:00
parent c92e613b82
commit 77fd060665
No known key found for this signature in database
GPG Key ID: 80D430DCBECFEDB4
18 changed files with 151 additions and 19 deletions

View File

@ -30,6 +30,10 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
const String & getAggregateFunctionCanonicalNameIfAny(const String & name)
{
return AggregateFunctionFactory::instance().getCanonicalNameIfAny(name);
}
void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness)
{
@ -41,10 +45,14 @@ void AggregateFunctionFactory::registerFunction(const String & name, Value creat
throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR);
if (case_sensitiveness == CaseInsensitive
&& !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator_with_properties).second)
throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR);
if (case_sensitiveness == CaseInsensitive)
{
auto key = Poco::toLower(name);
if (!case_insensitive_aggregate_functions.emplace(key, creator_with_properties).second)
throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR);
case_insensitive_name_mapping[key] = name;
}
}
static DataTypes convertLowCardinalityTypesToNested(const DataTypes & types)

View File

@ -35,6 +35,8 @@ protected:
return name;
}
std::unordered_map<String, String> case_insensitive_name_mapping;
public:
/// For compatibility with SQL, it's possible to specify that certain function name is case insensitive.
enum CaseSensitiveness
@ -68,9 +70,12 @@ public:
factory_name + ": the alias name '" + alias_name + "' is already registered as real name", ErrorCodes::LOGICAL_ERROR);
if (case_sensitiveness == CaseInsensitive)
{
if (!case_insensitive_aliases.emplace(alias_name_lowercase, real_dict_name).second)
throw Exception(
factory_name + ": case insensitive alias name '" + alias_name + "' is not unique", ErrorCodes::LOGICAL_ERROR);
case_insensitive_name_mapping[alias_name_lowercase] = real_name;
}
if (!aliases.emplace(alias_name, real_dict_name).second)
throw Exception(factory_name + ": alias name '" + alias_name + "' is not unique", ErrorCodes::LOGICAL_ERROR);
@ -111,6 +116,15 @@ public:
return getMap().count(name) || getCaseInsensitiveMap().count(name) || isAlias(name);
}
/// Return the canonical name (the name used in registration) if it's different from `name`.
const String & getCanonicalNameIfAny(const String & name) const
{
auto it = case_insensitive_name_mapping.find(Poco::toLower(name));
if (it != case_insensitive_name_mapping.end())
return it->second;
return name;
}
virtual ~IFactoryWithAliases() override {}
private:

View File

@ -21,6 +21,10 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
const String & getFunctionCanonicalNameIfAny(const String & name)
{
return FunctionFactory::instance().getCanonicalNameIfAny(name);
}
void FunctionFactory::registerFunction(const
std::string & name,
@ -36,10 +40,13 @@ void FunctionFactory::registerFunction(const
throw Exception("FunctionFactory: the function name '" + name + "' is already registered as alias",
ErrorCodes::LOGICAL_ERROR);
if (case_sensitiveness == CaseInsensitive
&& !case_insensitive_functions.emplace(function_name_lowercase, creator).second)
throw Exception("FunctionFactory: the case insensitive function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR);
if (case_sensitiveness == CaseInsensitive)
{
if (!case_insensitive_functions.emplace(function_name_lowercase, creator).second)
throw Exception("FunctionFactory: the case insensitive function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR);
case_insensitive_name_mapping[function_name_lowercase] = name;
}
}

View File

@ -8,7 +8,7 @@ namespace DB
void registerFunctionsRound(FunctionFactory & factory)
{
factory.registerFunction<FunctionRound>("round", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionRoundBankers>("roundBankers", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionRoundBankers>("roundBankers", FunctionFactory::CaseSensitive);
factory.registerFunction<FunctionFloor>("floor", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionCeil>("ceil", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionTrunc>("trunc", FunctionFactory::CaseInsensitive);

View File

@ -18,7 +18,7 @@ namespace DB
void registerFunctionExtractAllGroupsVertical(FunctionFactory & factory)
{
factory.registerFunction<FunctionExtractAllGroups<VerticalImpl>>();
factory.registerAlias("extractAllGroups", VerticalImpl::Name, FunctionFactory::CaseInsensitive);
factory.registerAlias("extractAllGroups", VerticalImpl::Name, FunctionFactory::CaseSensitive);
}
}

View File

@ -0,0 +1,18 @@
#include <Interpreters/FunctionNameNormalizer.h>
namespace DB
{
const String & getFunctionCanonicalNameIfAny(const String & name);
const String & getAggregateFunctionCanonicalNameIfAny(const String & name);
void FunctionNameNormalizer::visit(ASTPtr & ast)
{
if (auto * node_func = ast->as<ASTFunction>())
node_func->name = getAggregateFunctionCanonicalNameIfAny(getFunctionCanonicalNameIfAny(node_func->name));
for (auto & child : ast->children)
visit(child);
}
}

View File

@ -0,0 +1,14 @@
#pragma once
#include <Parsers/IAST.h>
#include <Parsers/ASTFunction.h>
namespace DB
{
struct FunctionNameNormalizer
{
static void visit(ASTPtr &);
};
}

View File

@ -442,10 +442,10 @@ ASTPtr MutationsInterpreter::prepare(bool dry_run)
auto type_literal = std::make_shared<ASTLiteral>(columns_desc.getPhysical(column).type->getName());
const auto & update_expr = kv.second;
auto updated_column = makeASTFunction("cast",
auto updated_column = makeASTFunction("CAST",
makeASTFunction("if",
getPartitionAndPredicateExpressionForMutationCommand(command),
makeASTFunction("cast",
makeASTFunction("CAST",
update_expr->clone(),
type_literal),
std::make_shared<ASTIdentifier>(column)),

View File

@ -8,6 +8,7 @@
#include <Interpreters/ArrayJoinedColumnsVisitor.h>
#include <Interpreters/TranslateQualifiedNamesVisitor.h>
#include <Interpreters/Context.h>
#include <Interpreters/FunctionNameNormalizer.h>
#include <Interpreters/MarkTableIdentifiersVisitor.h>
#include <Interpreters/QueryNormalizer.h>
#include <Interpreters/ExecuteScalarSubqueriesVisitor.h>
@ -934,6 +935,9 @@ void TreeRewriter::normalize(ASTPtr & query, Aliases & aliases, const Settings &
MarkTableIdentifiersVisitor::Data identifiers_data{aliases};
MarkTableIdentifiersVisitor(identifiers_data).visit(query);
/// Rewrite function names to their canonical ones.
FunctionNameNormalizer().visit(query);
/// Common subexpression elimination. Rewrite rules.
QueryNormalizer::Data normalizer_data(aliases, settings);
QueryNormalizer(normalizer_data).visit(query);

View File

@ -20,7 +20,7 @@ namespace ErrorCodes
ASTPtr addTypeConversionToAST(ASTPtr && ast, const String & type_name)
{
auto func = makeASTFunction("cast", ast, std::make_shared<ASTLiteral>(type_name));
auto func = makeASTFunction("CAST", ast, std::make_shared<ASTLiteral>(type_name));
if (ASTWithAlias * ast_with_alias = dynamic_cast<ASTWithAlias *>(ast.get()))
{

View File

@ -43,7 +43,7 @@ void addDefaultRequiredExpressionsRecursively(const Block & block, const String
RequiredSourceColumnsVisitor(columns_context).visit(column_default_expr);
NameSet required_columns_names = columns_context.requiredColumns();
auto cast_func = makeASTFunction("cast", column_default_expr, std::make_shared<ASTLiteral>(columns.get(required_column).type->getName()));
auto cast_func = makeASTFunction("CAST", column_default_expr, std::make_shared<ASTLiteral>(columns.get(required_column).type->getName()));
default_expr_list_accum->children.emplace_back(setAlias(cast_func, required_column));
added_columns.emplace(required_column);

View File

@ -626,7 +626,7 @@ void ConstantExpressionTemplate::TemplateStructure::addNodesToCastResult(const I
expr = makeASTFunction("assumeNotNull", std::move(expr));
}
expr = makeASTFunction("cast", std::move(expr), std::make_shared<ASTLiteral>(result_column_type.getName()));
expr = makeASTFunction("CAST", std::move(expr), std::make_shared<ASTLiteral>(result_column_type.getName()));
if (null_as_default)
{

View File

@ -217,7 +217,7 @@ def test_mysql_replacement_query(mysql_client, server_address):
--password=123 -e "select database();"
'''.format(host=server_address, port=server_port), demux=True)
assert code == 0
assert stdout.decode() == 'database()\ndefault\n'
assert stdout.decode() == 'DATABASE()\ndefault\n'
code, (stdout, stderr) = mysql_client.exec_run('''
mysql --protocol tcp -h {host} -P {port} default -u default

View File

@ -114,7 +114,7 @@ FROM
(
SELECT
1 AS id,
identity(cast(1, \'UInt8\')) AS subquery
identity(CAST(1, \'UInt8\')) AS subquery
WHERE subquery = 1
)
WHERE subquery = 1

View File

@ -2,7 +2,7 @@ SELECT 1
WHERE 0
SELECT 1
SELECT 1
WHERE (1 IN (0, 2)) AND (2 = (identity(cast(2, \'UInt8\')) AS subquery))
WHERE (1 IN (0, 2)) AND (2 = (identity(CAST(2, \'UInt8\')) AS subquery))
SELECT 1
WHERE 1 IN (
(

View File

@ -5,7 +5,7 @@ SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) FO
1,10
EXPLAIN SYNTAX SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n);
SELECT
identity(cast(0, \'UInt64\')) AS n,
identity(CAST(0, \'UInt64\')) AS n,
toUInt64(10 / n)
SELECT * FROM (WITH (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) as q SELECT * FROM system.one WHERE q > 0);
0

View File

@ -0,0 +1,66 @@
SELECT
CAST(1, 'INT'),
ceil(1),
ceil(1),
char(49),
CHAR_LENGTH('1'),
CHARACTER_LENGTH('1'),
coalesce(1),
concat('1', '1'),
corr(1, 1),
cos(1),
count(),
covarPop(1, 1),
covarSamp(1, 1),
DATABASE(),
dateDiff('DAY', toDate('2020-10-24'), toDate('2019-10-24')),
exp(1),
arrayFlatten([[1]]),
floor(1),
FQDN(),
greatest(1),
1,
ifNull(1, 1),
lower('A'),
least(1),
length('1'),
log(1),
position('1', '1'),
log(1),
log10(1),
log2(1),
lower('A'),
max(1),
substring('123', 1, 1),
min(1),
1 % 1,
NOT 1,
now(),
now64(),
nullIf(1, 1),
pi(),
position('123', '2'),
pow(1, 1),
pow(1, 1),
rand(),
replaceAll('1', '1', '2'),
reverse('123'),
round(1),
sin(1),
sqrt(1),
stddevPop(1),
stddevSamp(1),
substring('123', 2),
substring('123', 2),
count(),
tan(1),
tanh(1),
trunc(1),
trunc(1),
upper('A'),
upper('A'),
currentUser(),
varPop(1),
varSamp(1),
toWeek(toDate('2020-10-24')),
toYearWeek(toDate('2020-10-24'))

View File

@ -0,0 +1 @@
EXPLAIN SYNTAX SELECT CAST(1 AS INT), CEIL(1), CEILING(1), CHAR(49), CHAR_LENGTH('1'), CHARACTER_LENGTH('1'), COALESCE(1), CONCAT('1', '1'), CORR(1, 1), COS(1), COUNT(1), COVAR_POP(1, 1), COVAR_SAMP(1, 1), DATABASE(), DATEDIFF('DAY', toDate('2020-10-24'), toDate('2019-10-24')), EXP(1), FLATTEN([[1]]), FLOOR(1), FQDN(), GREATEST(1), IF(1, 1, 1), IFNULL(1, 1), LCASE('A'), LEAST(1), LENGTH('1'), LN(1), LOCATE('1', '1'), LOG(1), LOG10(1), LOG2(1), LOWER('A'), MAX(1), MID('123', 1, 1), MIN(1), MOD(1, 1), NOT(1), NOW(), NOW64(), NULLIF(1, 1), PI(), POSITION('123', '2'), POW(1, 1), POWER(1, 1), RAND(), REPLACE('1', '1', '2'), REVERSE('123'), ROUND(1), SIN(1), SQRT(1), STDDEV_POP(1), STDDEV_SAMP(1), SUBSTR('123', 2), SUBSTRING('123', 2), SUM(1), TAN(1), TANH(1), TRUNC(1), TRUNCATE(1), UCASE('A'), UPPER('A'), USER(), VAR_POP(1), VAR_SAMP(1), WEEK(toDate('2020-10-24')), YEARWEEK(toDate('2020-10-24')) format TSVRaw;