diff --git a/src/AggregateFunctions/AggregateFunctionFactory.cpp b/src/AggregateFunctions/AggregateFunctionFactory.cpp index 5fc690d59f2..061077dd8fa 100644 --- a/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -30,6 +30,10 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } +const String & getAggregateFunctionCanonicalNameIfAny(const String & name) +{ + return AggregateFunctionFactory::instance().getCanonicalNameIfAny(name); +} void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness) { @@ -41,10 +45,14 @@ void AggregateFunctionFactory::registerFunction(const String & name, Value creat throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique", ErrorCodes::LOGICAL_ERROR); - if (case_sensitiveness == CaseInsensitive - && !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator_with_properties).second) - throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique", - ErrorCodes::LOGICAL_ERROR); + if (case_sensitiveness == CaseInsensitive) + { + auto key = Poco::toLower(name); + if (!case_insensitive_aggregate_functions.emplace(key, creator_with_properties).second) + throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique", + ErrorCodes::LOGICAL_ERROR); + case_insensitive_name_mapping[key] = name; + } } static DataTypes convertLowCardinalityTypesToNested(const DataTypes & types) diff --git a/src/Common/IFactoryWithAliases.h b/src/Common/IFactoryWithAliases.h index 49c03049b92..5ef795c92d0 100644 --- a/src/Common/IFactoryWithAliases.h +++ b/src/Common/IFactoryWithAliases.h @@ -35,6 +35,8 @@ protected: return name; } + std::unordered_map case_insensitive_name_mapping; + public: /// For compatibility with SQL, it's possible to specify that certain function name is case insensitive. enum CaseSensitiveness @@ -68,9 +70,12 @@ public: factory_name + ": the alias name '" + alias_name + "' is already registered as real name", ErrorCodes::LOGICAL_ERROR); if (case_sensitiveness == CaseInsensitive) + { if (!case_insensitive_aliases.emplace(alias_name_lowercase, real_dict_name).second) throw Exception( factory_name + ": case insensitive alias name '" + alias_name + "' is not unique", ErrorCodes::LOGICAL_ERROR); + case_insensitive_name_mapping[alias_name_lowercase] = real_name; + } if (!aliases.emplace(alias_name, real_dict_name).second) throw Exception(factory_name + ": alias name '" + alias_name + "' is not unique", ErrorCodes::LOGICAL_ERROR); @@ -111,6 +116,15 @@ public: return getMap().count(name) || getCaseInsensitiveMap().count(name) || isAlias(name); } + /// Return the canonical name (the name used in registration) if it's different from `name`. + const String & getCanonicalNameIfAny(const String & name) const + { + auto it = case_insensitive_name_mapping.find(Poco::toLower(name)); + if (it != case_insensitive_name_mapping.end()) + return it->second; + return name; + } + virtual ~IFactoryWithAliases() override {} private: diff --git a/src/Functions/FunctionFactory.cpp b/src/Functions/FunctionFactory.cpp index 768f1cfe487..09fd360a925 100644 --- a/src/Functions/FunctionFactory.cpp +++ b/src/Functions/FunctionFactory.cpp @@ -21,6 +21,10 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } +const String & getFunctionCanonicalNameIfAny(const String & name) +{ + return FunctionFactory::instance().getCanonicalNameIfAny(name); +} void FunctionFactory::registerFunction(const std::string & name, @@ -36,10 +40,13 @@ void FunctionFactory::registerFunction(const throw Exception("FunctionFactory: the function name '" + name + "' is already registered as alias", ErrorCodes::LOGICAL_ERROR); - if (case_sensitiveness == CaseInsensitive - && !case_insensitive_functions.emplace(function_name_lowercase, creator).second) - throw Exception("FunctionFactory: the case insensitive function name '" + name + "' is not unique", - ErrorCodes::LOGICAL_ERROR); + if (case_sensitiveness == CaseInsensitive) + { + if (!case_insensitive_functions.emplace(function_name_lowercase, creator).second) + throw Exception("FunctionFactory: the case insensitive function name '" + name + "' is not unique", + ErrorCodes::LOGICAL_ERROR); + case_insensitive_name_mapping[function_name_lowercase] = name; + } } diff --git a/src/Functions/FunctionsRound.cpp b/src/Functions/FunctionsRound.cpp index b1349bd2164..c5ad27a0b90 100644 --- a/src/Functions/FunctionsRound.cpp +++ b/src/Functions/FunctionsRound.cpp @@ -8,7 +8,7 @@ namespace DB void registerFunctionsRound(FunctionFactory & factory) { factory.registerFunction("round", FunctionFactory::CaseInsensitive); - factory.registerFunction("roundBankers", FunctionFactory::CaseInsensitive); + factory.registerFunction("roundBankers", FunctionFactory::CaseSensitive); factory.registerFunction("floor", FunctionFactory::CaseInsensitive); factory.registerFunction("ceil", FunctionFactory::CaseInsensitive); factory.registerFunction("trunc", FunctionFactory::CaseInsensitive); diff --git a/src/Functions/extractAllGroupsVertical.cpp b/src/Functions/extractAllGroupsVertical.cpp index 9cbd148b016..bf33eef70f3 100644 --- a/src/Functions/extractAllGroupsVertical.cpp +++ b/src/Functions/extractAllGroupsVertical.cpp @@ -18,7 +18,7 @@ namespace DB void registerFunctionExtractAllGroupsVertical(FunctionFactory & factory) { factory.registerFunction>(); - factory.registerAlias("extractAllGroups", VerticalImpl::Name, FunctionFactory::CaseInsensitive); + factory.registerAlias("extractAllGroups", VerticalImpl::Name, FunctionFactory::CaseSensitive); } } diff --git a/src/Interpreters/FunctionNameNormalizer.cpp b/src/Interpreters/FunctionNameNormalizer.cpp new file mode 100644 index 00000000000..f22f72b5e03 --- /dev/null +++ b/src/Interpreters/FunctionNameNormalizer.cpp @@ -0,0 +1,18 @@ +#include + +namespace DB +{ + +const String & getFunctionCanonicalNameIfAny(const String & name); +const String & getAggregateFunctionCanonicalNameIfAny(const String & name); + +void FunctionNameNormalizer::visit(ASTPtr & ast) +{ + if (auto * node_func = ast->as()) + node_func->name = getAggregateFunctionCanonicalNameIfAny(getFunctionCanonicalNameIfAny(node_func->name)); + + for (auto & child : ast->children) + visit(child); +} + +} diff --git a/src/Interpreters/FunctionNameNormalizer.h b/src/Interpreters/FunctionNameNormalizer.h new file mode 100644 index 00000000000..2b20c28bce0 --- /dev/null +++ b/src/Interpreters/FunctionNameNormalizer.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +namespace DB +{ + +struct FunctionNameNormalizer +{ + static void visit(ASTPtr &); +}; + +} diff --git a/src/Interpreters/MutationsInterpreter.cpp b/src/Interpreters/MutationsInterpreter.cpp index 528b5ec6d8e..c393b214ee8 100644 --- a/src/Interpreters/MutationsInterpreter.cpp +++ b/src/Interpreters/MutationsInterpreter.cpp @@ -442,10 +442,10 @@ ASTPtr MutationsInterpreter::prepare(bool dry_run) auto type_literal = std::make_shared(columns_desc.getPhysical(column).type->getName()); const auto & update_expr = kv.second; - auto updated_column = makeASTFunction("cast", + auto updated_column = makeASTFunction("CAST", makeASTFunction("if", getPartitionAndPredicateExpressionForMutationCommand(command), - makeASTFunction("cast", + makeASTFunction("CAST", update_expr->clone(), type_literal), std::make_shared(column)), diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index fd87d86bf97..cf4db8f174e 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ b/src/Interpreters/TreeRewriter.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -934,6 +935,9 @@ void TreeRewriter::normalize(ASTPtr & query, Aliases & aliases, const Settings & MarkTableIdentifiersVisitor::Data identifiers_data{aliases}; MarkTableIdentifiersVisitor(identifiers_data).visit(query); + /// Rewrite function names to their canonical ones. + FunctionNameNormalizer().visit(query); + /// Common subexpression elimination. Rewrite rules. QueryNormalizer::Data normalizer_data(aliases, settings); QueryNormalizer(normalizer_data).visit(query); diff --git a/src/Interpreters/addTypeConversionToAST.cpp b/src/Interpreters/addTypeConversionToAST.cpp index bb42ad79daa..18591fd732c 100644 --- a/src/Interpreters/addTypeConversionToAST.cpp +++ b/src/Interpreters/addTypeConversionToAST.cpp @@ -20,7 +20,7 @@ namespace ErrorCodes ASTPtr addTypeConversionToAST(ASTPtr && ast, const String & type_name) { - auto func = makeASTFunction("cast", ast, std::make_shared(type_name)); + auto func = makeASTFunction("CAST", ast, std::make_shared(type_name)); if (ASTWithAlias * ast_with_alias = dynamic_cast(ast.get())) { diff --git a/src/Interpreters/inplaceBlockConversions.cpp b/src/Interpreters/inplaceBlockConversions.cpp index eba03d7aa61..c9a96a81b48 100644 --- a/src/Interpreters/inplaceBlockConversions.cpp +++ b/src/Interpreters/inplaceBlockConversions.cpp @@ -43,7 +43,7 @@ void addDefaultRequiredExpressionsRecursively(const Block & block, const String RequiredSourceColumnsVisitor(columns_context).visit(column_default_expr); NameSet required_columns_names = columns_context.requiredColumns(); - auto cast_func = makeASTFunction("cast", column_default_expr, std::make_shared(columns.get(required_column).type->getName())); + auto cast_func = makeASTFunction("CAST", column_default_expr, std::make_shared(columns.get(required_column).type->getName())); default_expr_list_accum->children.emplace_back(setAlias(cast_func, required_column)); added_columns.emplace(required_column); diff --git a/src/Processors/Formats/Impl/ConstantExpressionTemplate.cpp b/src/Processors/Formats/Impl/ConstantExpressionTemplate.cpp index d7a65c2f15d..1685688f02d 100644 --- a/src/Processors/Formats/Impl/ConstantExpressionTemplate.cpp +++ b/src/Processors/Formats/Impl/ConstantExpressionTemplate.cpp @@ -626,7 +626,7 @@ void ConstantExpressionTemplate::TemplateStructure::addNodesToCastResult(const I expr = makeASTFunction("assumeNotNull", std::move(expr)); } - expr = makeASTFunction("cast", std::move(expr), std::make_shared(result_column_type.getName())); + expr = makeASTFunction("CAST", std::move(expr), std::make_shared(result_column_type.getName())); if (null_as_default) { diff --git a/tests/integration/test_mysql_protocol/test.py b/tests/integration/test_mysql_protocol/test.py index 9532d4b8ba2..7f7d59674bc 100644 --- a/tests/integration/test_mysql_protocol/test.py +++ b/tests/integration/test_mysql_protocol/test.py @@ -217,7 +217,7 @@ def test_mysql_replacement_query(mysql_client, server_address): --password=123 -e "select database();" '''.format(host=server_address, port=server_port), demux=True) assert code == 0 - assert stdout.decode() == 'database()\ndefault\n' + assert stdout.decode() == 'DATABASE()\ndefault\n' code, (stdout, stderr) = mysql_client.exec_run(''' mysql --protocol tcp -h {host} -P {port} default -u default diff --git a/tests/queries/0_stateless/00597_push_down_predicate.reference b/tests/queries/0_stateless/00597_push_down_predicate.reference index 794d9e7af5f..bd1c4791df4 100644 --- a/tests/queries/0_stateless/00597_push_down_predicate.reference +++ b/tests/queries/0_stateless/00597_push_down_predicate.reference @@ -114,7 +114,7 @@ FROM ( SELECT 1 AS id, - identity(cast(1, \'UInt8\')) AS subquery + identity(CAST(1, \'UInt8\')) AS subquery WHERE subquery = 1 ) WHERE subquery = 1 diff --git a/tests/queries/0_stateless/01029_early_constant_folding.reference b/tests/queries/0_stateless/01029_early_constant_folding.reference index 8a1d4cec388..8a2d7e6c61a 100644 --- a/tests/queries/0_stateless/01029_early_constant_folding.reference +++ b/tests/queries/0_stateless/01029_early_constant_folding.reference @@ -2,7 +2,7 @@ SELECT 1 WHERE 0 SELECT 1 SELECT 1 -WHERE (1 IN (0, 2)) AND (2 = (identity(cast(2, \'UInt8\')) AS subquery)) +WHERE (1 IN (0, 2)) AND (2 = (identity(CAST(2, \'UInt8\')) AS subquery)) SELECT 1 WHERE 1 IN ( ( diff --git a/tests/queries/0_stateless/01611_constant_folding_subqueries.reference b/tests/queries/0_stateless/01611_constant_folding_subqueries.reference index d10502c5860..e46fd479413 100644 --- a/tests/queries/0_stateless/01611_constant_folding_subqueries.reference +++ b/tests/queries/0_stateless/01611_constant_folding_subqueries.reference @@ -5,7 +5,7 @@ SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) FO 1,10 EXPLAIN SYNTAX SELECT (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n); SELECT - identity(cast(0, \'UInt64\')) AS n, + identity(CAST(0, \'UInt64\')) AS n, toUInt64(10 / n) SELECT * FROM (WITH (SELECT * FROM system.numbers LIMIT 1 OFFSET 1) AS n, toUInt64(10 / n) as q SELECT * FROM system.one WHERE q > 0); 0 diff --git a/tests/queries/0_stateless/01705_normalize_case_insensitive_function_names.reference b/tests/queries/0_stateless/01705_normalize_case_insensitive_function_names.reference new file mode 100644 index 00000000000..5b0f7bdeb2d --- /dev/null +++ b/tests/queries/0_stateless/01705_normalize_case_insensitive_function_names.reference @@ -0,0 +1,66 @@ +SELECT + CAST(1, 'INT'), + ceil(1), + ceil(1), + char(49), + CHAR_LENGTH('1'), + CHARACTER_LENGTH('1'), + coalesce(1), + concat('1', '1'), + corr(1, 1), + cos(1), + count(), + covarPop(1, 1), + covarSamp(1, 1), + DATABASE(), + dateDiff('DAY', toDate('2020-10-24'), toDate('2019-10-24')), + exp(1), + arrayFlatten([[1]]), + floor(1), + FQDN(), + greatest(1), + 1, + ifNull(1, 1), + lower('A'), + least(1), + length('1'), + log(1), + position('1', '1'), + log(1), + log10(1), + log2(1), + lower('A'), + max(1), + substring('123', 1, 1), + min(1), + 1 % 1, + NOT 1, + now(), + now64(), + nullIf(1, 1), + pi(), + position('123', '2'), + pow(1, 1), + pow(1, 1), + rand(), + replaceAll('1', '1', '2'), + reverse('123'), + round(1), + sin(1), + sqrt(1), + stddevPop(1), + stddevSamp(1), + substring('123', 2), + substring('123', 2), + count(), + tan(1), + tanh(1), + trunc(1), + trunc(1), + upper('A'), + upper('A'), + currentUser(), + varPop(1), + varSamp(1), + toWeek(toDate('2020-10-24')), + toYearWeek(toDate('2020-10-24')) diff --git a/tests/queries/0_stateless/01705_normalize_case_insensitive_function_names.sql b/tests/queries/0_stateless/01705_normalize_case_insensitive_function_names.sql new file mode 100644 index 00000000000..9b35087182c --- /dev/null +++ b/tests/queries/0_stateless/01705_normalize_case_insensitive_function_names.sql @@ -0,0 +1 @@ +EXPLAIN SYNTAX SELECT CAST(1 AS INT), CEIL(1), CEILING(1), CHAR(49), CHAR_LENGTH('1'), CHARACTER_LENGTH('1'), COALESCE(1), CONCAT('1', '1'), CORR(1, 1), COS(1), COUNT(1), COVAR_POP(1, 1), COVAR_SAMP(1, 1), DATABASE(), DATEDIFF('DAY', toDate('2020-10-24'), toDate('2019-10-24')), EXP(1), FLATTEN([[1]]), FLOOR(1), FQDN(), GREATEST(1), IF(1, 1, 1), IFNULL(1, 1), LCASE('A'), LEAST(1), LENGTH('1'), LN(1), LOCATE('1', '1'), LOG(1), LOG10(1), LOG2(1), LOWER('A'), MAX(1), MID('123', 1, 1), MIN(1), MOD(1, 1), NOT(1), NOW(), NOW64(), NULLIF(1, 1), PI(), POSITION('123', '2'), POW(1, 1), POWER(1, 1), RAND(), REPLACE('1', '1', '2'), REVERSE('123'), ROUND(1), SIN(1), SQRT(1), STDDEV_POP(1), STDDEV_SAMP(1), SUBSTR('123', 2), SUBSTRING('123', 2), SUM(1), TAN(1), TANH(1), TRUNC(1), TRUNCATE(1), UCASE('A'), UPPER('A'), USER(), VAR_POP(1), VAR_SAMP(1), WEEK(toDate('2020-10-24')), YEARWEEK(toDate('2020-10-24')) format TSVRaw;