diff --git a/dbms/include/DB/Interpreters/evaluateDatabaseName.h b/dbms/include/DB/Interpreters/evaluateDatabaseName.h deleted file mode 100644 index 8ed967f9bce..00000000000 --- a/dbms/include/DB/Interpreters/evaluateDatabaseName.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - - -namespace DB -{ - namespace - { - ASTPtr makeIdentifier(const ASTPtr & expr, const Context & context) - { - /// for identifier just return its name - if (typeid_cast(expr.get())) - return expr; - - /// for string literal return its value - if (const auto literal = typeid_cast(expr.get())) - return new ASTIdentifier{{}, safeGet(literal->value)}; - - /// otherwise evaluate expression and ensure it has string type - Block block{}; - ExpressionAnalyzer{expr, context, { { "", new DataTypeString } }}.getActions(false)->execute(block); - - const auto & column_name_type = block.getByName(expr->getColumnName()); - - if (!typeid_cast(column_name_type.type.get())) - throw Exception{""}; - - return new ASTIdentifier{{}, column_name_type.column->getDataAt(0).toString()}; - } - - String evaluateDatabaseName(ASTPtr & expr, const Context & context) - { - expr = makeIdentifier(expr, context); - return static_cast(expr.get())->name; - } - } -} diff --git a/dbms/include/DB/Interpreters/reinterpretAsIdentifier.h b/dbms/include/DB/Interpreters/reinterpretAsIdentifier.h new file mode 100644 index 00000000000..93315138bea --- /dev/null +++ b/dbms/include/DB/Interpreters/reinterpretAsIdentifier.h @@ -0,0 +1,50 @@ +#pragma once + +#include +#include +#include +#include +#include + + +namespace DB +{ + namespace + { + ASTPtr reinterpretAsIdentifierImpl(const ASTPtr & expr, const Context & context) + { + /// for string literal return its value + if (const auto literal = typeid_cast(expr.get())) + return new ASTIdentifier{{}, safeGet(literal->value)}; + + /// otherwise evaluate the expression + Block block{}; + /** pass a dummy column name because ExpressioAnalyzer + * does not work with no columns so far. */ + ExpressionAnalyzer{ + expr, context, + { { "", new DataTypeString } } + }.getActions(false)->execute(block); + + const auto & column_name_type = block.getByName(expr->getColumnName()); + + /// ensure the result of evaluation has String type + if (!typeid_cast(column_name_type.type.get())) + throw Exception{"Expression must evaluate to a String"}; + + return new ASTIdentifier{{}, column_name_type.column->getDataAt(0).toString()}; + } + } + + /** \brief if `expr` is not already ASTIdentifier evaluates it + * and replaces by a new ASTIdentifier with the result of evaluation as its name. + * `expr` must evaluate to a String type */ + inline ASTIdentifier & reinterpretAsIdentifier(ASTPtr & expr, const Context & context) + { + /// for identifier just return its name + if (!typeid_cast(expr.get())) + expr = reinterpretAsIdentifierImpl(expr, context); + + return static_cast(*expr); + } +} diff --git a/dbms/include/DB/TableFunctions/TableFunctionMerge.h b/dbms/include/DB/TableFunctions/TableFunctionMerge.h index cec8dcd632f..467604577f1 100644 --- a/dbms/include/DB/TableFunctions/TableFunctionMerge.h +++ b/dbms/include/DB/TableFunctions/TableFunctionMerge.h @@ -8,7 +8,7 @@ #include #include #include -#include +#include namespace DB @@ -41,7 +41,7 @@ public: " - name of source database and regexp for table names.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - String source_database = evaluateDatabaseName(args[0], context); + String source_database = reinterpretAsIdentifier(args[0], context).name; String table_name_regexp = safeGet(typeid_cast(*args[1]).value); /// В InterpreterSelectQuery будет создан ExpressionAnalzyer, который при обработке запроса наткнется на этот Identifier. diff --git a/dbms/include/DB/TableFunctions/TableFunctionRemote.h b/dbms/include/DB/TableFunctions/TableFunctionRemote.h index 6912a7baf84..3cc19da6020 100644 --- a/dbms/include/DB/TableFunctions/TableFunctionRemote.h +++ b/dbms/include/DB/TableFunctions/TableFunctionRemote.h @@ -4,7 +4,7 @@ #include #include #include -#include +#include struct data; @@ -44,7 +44,7 @@ public: throw Exception(err, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); String descripton = safeGet(typeid_cast(*args[0]).value); - String remote_database = evaluateDatabaseName(args[1], context); + String remote_database = reinterpretAsIdentifier(args[1], context).name; String remote_table = args.size() % 2 ? typeid_cast(*args[2]).name : ""; String username = args.size() >= 4 ? safeGet(typeid_cast(*args[args.size() - 2]).value) : "default"; diff --git a/dbms/src/Storages/StorageFactory.cpp b/dbms/src/Storages/StorageFactory.cpp index c551bbf1b14..49c5625b871 100644 --- a/dbms/src/Storages/StorageFactory.cpp +++ b/dbms/src/Storages/StorageFactory.cpp @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include @@ -108,7 +108,7 @@ StoragePtr StorageFactory::get( if (args.size() < 3 || args.size() > 4) break; - String source_database = evaluateDatabaseName(args[0], local_context); + String source_database = reinterpretAsIdentifier(args[0], local_context).name; String source_table_name_regexp = safeGet(typeid_cast(*args[1]).value); size_t chunks_to_merge = safeGet(typeid_cast(*args[2]).value); @@ -156,7 +156,7 @@ StoragePtr StorageFactory::get( " - name of source database and regexp for table names.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - String source_database = evaluateDatabaseName(args[0], local_context); + String source_database = reinterpretAsIdentifier(args[0], local_context).name; String table_name_regexp = safeGet(typeid_cast(*args[1]).value); return StorageMerge::create(table_name, columns, source_database, table_name_regexp, context); @@ -182,7 +182,7 @@ StoragePtr StorageFactory::get( ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); String cluster_name = typeid_cast(*args[0]).name; - String remote_database = evaluateDatabaseName(args[1], local_context); + String remote_database = reinterpretAsIdentifier(args[1], local_context).name; String remote_table = typeid_cast(*args[2]).name; const auto & sharding_key = args.size() == 4 ? args[3] : nullptr;