mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-24 00:22:29 +00:00
Normalize function names
This commit is contained in:
parent
c92e613b82
commit
77fd060665
@ -30,6 +30,10 @@ namespace ErrorCodes
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
/// Free-function wrapper: forwards `name` to getCanonicalNameIfAny() on the object returned by
/// AggregateFunctionFactory::instance() and returns that call's result by const reference.
/// NOTE(review): presumably provided so that other translation units can forward-declare this
/// function instead of including the factory header - confirm.
const String & getAggregateFunctionCanonicalNameIfAny(const String & name)
|
||||
{
|
||||
return AggregateFunctionFactory::instance().getCanonicalNameIfAny(name);
|
||||
}
|
||||
|
||||
void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness)
|
||||
{
|
||||
@ -41,10 +45,14 @@ void AggregateFunctionFactory::registerFunction(const String & name, Value creat
|
||||
throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique",
|
||||
ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
if (case_sensitiveness == CaseInsensitive
|
||||
&& !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator_with_properties).second)
|
||||
throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique",
|
||||
ErrorCodes::LOGICAL_ERROR);
|
||||
if (case_sensitiveness == CaseInsensitive)
|
||||
{
|
||||
auto key = Poco::toLower(name);
|
||||
if (!case_insensitive_aggregate_functions.emplace(key, creator_with_properties).second)
|
||||
throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique",
|
||||
ErrorCodes::LOGICAL_ERROR);
|
||||
case_insensitive_name_mapping[key] = name;
|
||||
}
|
||||
}
|
||||
|
||||
static DataTypes convertLowCardinalityTypesToNested(const DataTypes & types)
|
||||
|
@ -35,6 +35,8 @@ protected:
|
||||
return name;
|
||||
}
|
||||
|
||||
std::unordered_map<String, String> case_insensitive_name_mapping;
|
||||
|
||||
public:
|
||||
/// For compatibility with SQL, it's possible to specify that a certain function name is case insensitive.
|
||||
enum CaseSensitiveness
|
||||
@ -68,9 +70,12 @@ public:
|
||||
factory_name + ": the alias name '" + alias_name + "' is already registered as real name", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
if (case_sensitiveness == CaseInsensitive)
|
||||
{
|
||||
if (!case_insensitive_aliases.emplace(alias_name_lowercase, real_dict_name).second)
|
||||
throw Exception(
|
||||
factory_name + ": case insensitive alias name '" + alias_name + "' is not unique", ErrorCodes::LOGICAL_ERROR);
|
||||
case_insensitive_name_mapping[alias_name_lowercase] = real_name;
|
||||
}
|
||||
|
||||
if (!aliases.emplace(alias_name, real_dict_name).second)
|
||||
throw Exception(factory_name + ": alias name '" + alias_name + "' is not unique", ErrorCodes::LOGICAL_ERROR);
|
||||
@ -111,6 +116,15 @@ public:
|
||||
return getMap().count(name) || getCaseInsensitiveMap().count(name) || isAlias(name);
|
||||
}
|
||||
|
||||
/// Return the canonical name (the name used in registration) if it's different from `name`.
/// The lookup key is the lower-cased `name` in `case_insensitive_name_mapping`; when there is no
/// entry for that key, `name` itself is returned.
/// The result is a reference either to a value stored in the mapping or to the argument, so in
/// the latter case it is only valid as long as the string passed as `name` is alive
/// (do not keep the reference after passing a temporary).
|
||||
const String & getCanonicalNameIfAny(const String & name) const
|
||||
{
|
||||
auto it = case_insensitive_name_mapping.find(Poco::toLower(name));
|
||||
if (it != case_insensitive_name_mapping.end())
|
||||
return it->second;
|
||||
return name;
|
||||
}
|
||||
|
||||
virtual ~IFactoryWithAliases() override {}
|
||||
|
||||
private:
|
||||
|
@ -21,6 +21,10 @@ namespace ErrorCodes
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
/// Free-function wrapper: forwards `name` to getCanonicalNameIfAny() on the object returned by
/// FunctionFactory::instance() and returns that call's result by const reference.
/// NOTE(review): presumably provided so that other translation units can forward-declare this
/// function instead of including the factory header - confirm.
const String & getFunctionCanonicalNameIfAny(const String & name)
|
||||
{
|
||||
return FunctionFactory::instance().getCanonicalNameIfAny(name);
|
||||
}
|
||||
|
||||
void FunctionFactory::registerFunction(const
|
||||
std::string & name,
|
||||
@ -36,10 +40,13 @@ void FunctionFactory::registerFunction(const
|
||||
throw Exception("FunctionFactory: the function name '" + name + "' is already registered as alias",
|
||||
ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
if (case_sensitiveness == CaseInsensitive
|
||||
&& !case_insensitive_functions.emplace(function_name_lowercase, creator).second)
|
||||
throw Exception("FunctionFactory: the case insensitive function name '" + name + "' is not unique",
|
||||
ErrorCodes::LOGICAL_ERROR);
|
||||
if (case_sensitiveness == CaseInsensitive)
|
||||
{
|
||||
if (!case_insensitive_functions.emplace(function_name_lowercase, creator).second)
|
||||
throw Exception("FunctionFactory: the case insensitive function name '" + name + "' is not unique",
|
||||
ErrorCodes::LOGICAL_ERROR);
|
||||
case_insensitive_name_mapping[function_name_lowercase] = name;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -8,7 +8,7 @@ namespace DB
|
||||
void registerFunctionsRound(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionRound>("round", FunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction<FunctionRoundBankers>("roundBankers", FunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction<FunctionRoundBankers>("roundBankers", FunctionFactory::CaseSensitive);
|
||||
factory.registerFunction<FunctionFloor>("floor", FunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction<FunctionCeil>("ceil", FunctionFactory::CaseInsensitive);
|
||||
factory.registerFunction<FunctionTrunc>("trunc", FunctionFactory::CaseInsensitive);
|
||||
|
@ -18,7 +18,7 @@ namespace DB
|
||||
/// Registers FunctionExtractAllGroups<VerticalImpl> in `factory` and calls registerAlias with
/// "extractAllGroups" and VerticalImpl::Name.
/// NOTE(review): this text is a rendered diff without +/- markers; the two registerAlias lines
/// below look like the before (CaseInsensitive) and after (CaseSensitive) versions of a single
/// changed line rather than two calls that are both meant to run - confirm against the commit.
void registerFunctionExtractAllGroupsVertical(FunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction<FunctionExtractAllGroups<VerticalImpl>>();
|
||||
factory.registerAlias("extractAllGroups", VerticalImpl::Name, FunctionFactory::CaseInsensitive);
|
||||
factory.registerAlias("extractAllGroups", VerticalImpl::Name, FunctionFactory::CaseSensitive);
|
||||
}
|
||||
|
||||
}
|
||||
|
18
src/Interpreters/FunctionNameNormalizer.cpp
Normal file
18
src/Interpreters/FunctionNameNormalizer.cpp
Normal file
@ -0,0 +1,18 @@
|
||||
#include <Interpreters/FunctionNameNormalizer.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
const String & getFunctionCanonicalNameIfAny(const String & name);
|
||||
const String & getAggregateFunctionCanonicalNameIfAny(const String & name);
|
||||
|
||||
/// Walks the tree rooted at `ast` recursively and rewrites the `name` of every ASTFunction node.
/// The current name is passed through getFunctionCanonicalNameIfAny() first, and the result of
/// that call is passed through getAggregateFunctionCanonicalNameIfAny(); the final result is
/// assigned back to the node.
/// `ast` is dereferenced without a check, so it must not be null.
void FunctionNameNormalizer::visit(ASTPtr & ast)
|
||||
{
|
||||
if (auto * node_func = ast->as<ASTFunction>())
|
||||
node_func->name = getAggregateFunctionCanonicalNameIfAny(getFunctionCanonicalNameIfAny(node_func->name));
|
||||
|
||||
/// Descend into all children regardless of the type of the current node.
for (auto & child : ast->children)
|
||||
visit(child);
|
||||
}
|
||||
|
||||
}
|
14
src/Interpreters/FunctionNameNormalizer.h
Normal file
14
src/Interpreters/FunctionNameNormalizer.h
Normal file
@ -0,0 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <Parsers/IAST.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// AST pass that rewrites function names (see visit()).
/// As declared here the struct has no data members; its only member is a static function.
struct FunctionNameNormalizer
|
||||
{
|
||||
/// Processes the given AST in place; the argument is taken by non-const reference.
static void visit(ASTPtr &);
|
||||
};
|
||||
|
||||
}
|
@ -442,10 +442,10 @@ ASTPtr MutationsInterpreter::prepare(bool dry_run)
|
||||
auto type_literal = std::make_shared<ASTLiteral>(columns_desc.getPhysical(column).type->getName());
|
||||
|
||||
const auto & update_expr = kv.second;
|
||||
auto updated_column = makeASTFunction("cast",
|
||||
auto updated_column = makeASTFunction("CAST",
|
||||
makeASTFunction("if",
|
||||
getPartitionAndPredicateExpressionForMutationCommand(command),
|
||||
makeASTFunction("cast",
|
||||
makeASTFunction("CAST",
|
||||
update_expr->clone(),
|
||||
type_literal),
|
||||
std::make_shared<ASTIdentifier>(column)),
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <Interpreters/ArrayJoinedColumnsVisitor.h>
|
||||
#include <Interpreters/TranslateQualifiedNamesVisitor.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Interpreters/FunctionNameNormalizer.h>
|
||||
#include <Interpreters/MarkTableIdentifiersVisitor.h>
|
||||
#include <Interpreters/QueryNormalizer.h>
|
||||
#include <Interpreters/ExecuteScalarSubqueriesVisitor.h>
|
||||
@ -934,6 +935,9 @@ void TreeRewriter::normalize(ASTPtr & query, Aliases & aliases, const Settings &
|
||||
MarkTableIdentifiersVisitor::Data identifiers_data{aliases};
|
||||
MarkTableIdentifiersVisitor(identifiers_data).visit(query);
|
||||
|
||||
/// Rewrite function names to their canonical ones.
|
||||
FunctionNameNormalizer().visit(query);
|
||||
|
||||
/// Common subexpression elimination. Rewrite rules.
|
||||
QueryNormalizer::Data normalizer_data(aliases, settings);
|
||||
QueryNormalizer(normalizer_data).visit(query);
|
||||
|
@ -20,7 +20,7 @@ namespace ErrorCodes
|
||||
|
||||
ASTPtr addTypeConversionToAST(ASTPtr && ast, const String & type_name)
|
||||
{
|
||||
auto func = makeASTFunction("cast", ast, std::make_shared<ASTLiteral>(type_name));
|
||||
auto func = makeASTFunction("CAST", ast, std::make_shared<ASTLiteral>(type_name));
|
||||
|
||||
if (ASTWithAlias * ast_with_alias = dynamic_cast<ASTWithAlias *>(ast.get()))
|
||||
{
|
||||
|
@ -43,7 +43,7 @@ void addDefaultRequiredExpressionsRecursively(const Block & block, const String
|
||||
RequiredSourceColumnsVisitor(columns_context).visit(column_default_expr);
|
||||
NameSet required_columns_names = columns_context.requiredColumns();
|
||||
|
||||
auto cast_func = makeASTFunction("cast", column_default_expr, std::make_shared<ASTLiteral>(columns.get(required_column).type->getName()));
|
||||
auto cast_func = makeASTFunction("CAST", column_default_expr, std::make_shared<ASTLiteral>(columns.get(required_column).type->getName()));
|
||||
default_expr_list_accum->children.emplace_back(setAlias(cast_func, required_column));
|
||||
added_columns.emplace(required_column);
|
||||
|
||||
|
@ -626,7 +626,7 @@ void ConstantExpressionTemplate::TemplateStructure::addNodesToCastResult(const I
|
||||
expr = makeASTFunction("assumeNotNull", std::move(expr));
|
||||
}
|
||||
|
||||
expr = makeASTFunction("cast", std::move(expr), std::make_shared<ASTLiteral>(result_column_type.getName()));
|
||||
expr = makeASTFunction("CAST", std::move(expr), std::make_shared<ASTLiteral>(result_column_type.getName()));
|
||||
|
||||
if (null_as_default)
|
||||
{
|
||||
|
@ -217,7 +217,7 @@ def test_mysql_replacement_query(mysql_client, server_address):
|
||||
--password=123 -e "select database();"
|
||||
'''.format(host=server_address, port=server_port), demux=True)
|
||||
assert code == 0
|
||||
assert stdout.decode() == 'database()\ndefault\n'
|
||||
assert stdout.decode() == 'DATABASE()\ndefault\n'
|
||||
|
||||
code, (stdout, stderr) = mysql_client.exec_run('''
|
||||
mysql --protocol tcp -h {host} -P {port} default -u default
|
||||
|
@ -114,7 +114,7 @@ FROM
|
||||
(
|
||||
SELECT
|
||||
1 AS id,
|
||||
identity(cast(1, \'UInt8\')) AS subquery
|
||||
identity(CAST(1, \'UInt8\')) AS subquery
|
||||
WHERE subquery = 1
|
||||
)
|
||||
WHERE subquery = 1
|
||||
|
@ -2,7 +2,7 @@ SELECT 1
|
||||
WHERE 0
|
||||
SELECT 1
|
||||
SELECT 1
|
||||
WHERE (1 IN (0, 2)) AND (2 = (identity(cast(2, \'UInt8\')) AS subquery))
|
||||
WHERE (1 IN (0, 2)) AND (2 = (identity(CAST(2, \'UInt8\')) AS subquery))
|
||||
SELECT 1
|
||||
WHERE 1 IN (
|
||||
(
|
||||
|
@ -5,7 +5,7 @@ SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) FO
|
||||
1,10
|
||||
EXPLAIN SYNTAX SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n);
|
||||
SELECT
|
||||
identity(cast(0, \'UInt64\')) AS n,
|
||||
identity(CAST(0, \'UInt64\')) AS n,
|
||||
toUInt64(10 / n)
|
||||
SELECT * FROM (WITH (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) as q SELECT * FROM system.one WHERE q > 0);
|
||||
0
|
||||
|
@ -0,0 +1,66 @@
|
||||
SELECT
|
||||
CAST(1, 'INT'),
|
||||
ceil(1),
|
||||
ceil(1),
|
||||
char(49),
|
||||
CHAR_LENGTH('1'),
|
||||
CHARACTER_LENGTH('1'),
|
||||
coalesce(1),
|
||||
concat('1', '1'),
|
||||
corr(1, 1),
|
||||
cos(1),
|
||||
count(),
|
||||
covarPop(1, 1),
|
||||
covarSamp(1, 1),
|
||||
DATABASE(),
|
||||
dateDiff('DAY', toDate('2020-10-24'), toDate('2019-10-24')),
|
||||
exp(1),
|
||||
arrayFlatten([[1]]),
|
||||
floor(1),
|
||||
FQDN(),
|
||||
greatest(1),
|
||||
1,
|
||||
ifNull(1, 1),
|
||||
lower('A'),
|
||||
least(1),
|
||||
length('1'),
|
||||
log(1),
|
||||
position('1', '1'),
|
||||
log(1),
|
||||
log10(1),
|
||||
log2(1),
|
||||
lower('A'),
|
||||
max(1),
|
||||
substring('123', 1, 1),
|
||||
min(1),
|
||||
1 % 1,
|
||||
NOT 1,
|
||||
now(),
|
||||
now64(),
|
||||
nullIf(1, 1),
|
||||
pi(),
|
||||
position('123', '2'),
|
||||
pow(1, 1),
|
||||
pow(1, 1),
|
||||
rand(),
|
||||
replaceAll('1', '1', '2'),
|
||||
reverse('123'),
|
||||
round(1),
|
||||
sin(1),
|
||||
sqrt(1),
|
||||
stddevPop(1),
|
||||
stddevSamp(1),
|
||||
substring('123', 2),
|
||||
substring('123', 2),
|
||||
count(),
|
||||
tan(1),
|
||||
tanh(1),
|
||||
trunc(1),
|
||||
trunc(1),
|
||||
upper('A'),
|
||||
upper('A'),
|
||||
currentUser(),
|
||||
varPop(1),
|
||||
varSamp(1),
|
||||
toWeek(toDate('2020-10-24')),
|
||||
toYearWeek(toDate('2020-10-24'))
|
@ -0,0 +1 @@
|
||||
EXPLAIN SYNTAX SELECT CAST(1 AS INT), CEIL(1), CEILING(1), CHAR(49), CHAR_LENGTH('1'), CHARACTER_LENGTH('1'), COALESCE(1), CONCAT('1', '1'), CORR(1, 1), COS(1), COUNT(1), COVAR_POP(1, 1), COVAR_SAMP(1, 1), DATABASE(), DATEDIFF('DAY', toDate('2020-10-24'), toDate('2019-10-24')), EXP(1), FLATTEN([[1]]), FLOOR(1), FQDN(), GREATEST(1), IF(1, 1, 1), IFNULL(1, 1), LCASE('A'), LEAST(1), LENGTH('1'), LN(1), LOCATE('1', '1'), LOG(1), LOG10(1), LOG2(1), LOWER('A'), MAX(1), MID('123', 1, 1), MIN(1), MOD(1, 1), NOT(1), NOW(), NOW64(), NULLIF(1, 1), PI(), POSITION('123', '2'), POW(1, 1), POWER(1, 1), RAND(), REPLACE('1', '1', '2'), REVERSE('123'), ROUND(1), SIN(1), SQRT(1), STDDEV_POP(1), STDDEV_SAMP(1), SUBSTR('123', 2), SUBSTRING('123', 2), SUM(1), TAN(1), TANH(1), TRUNC(1), TRUNCATE(1), UCASE('A'), UPPER('A'), USER(), VAR_POP(1), VAR_SAMP(1), WEEK(toDate('2020-10-24')), YEARWEEK(toDate('2020-10-24')) format TSVRaw;
|
Loading…
Reference in New Issue
Block a user