diff --git a/dbms/programs/client/Client.cpp b/dbms/programs/client/Client.cpp index 6b4a0c6eb58..724bbc9eb93 100644 --- a/dbms/programs/client/Client.cpp +++ b/dbms/programs/client/Client.cpp @@ -704,7 +704,7 @@ private: return true; } - ASTInsertQuery * insert = typeid_cast(ast.get()); + auto * insert = ast->as(); if (insert && insert->data) { @@ -799,14 +799,11 @@ private: written_progress_chars = 0; written_first_block = false; - const ASTSetQuery * set_query = typeid_cast(&*parsed_query); - const ASTUseQuery * use_query = typeid_cast(&*parsed_query); - /// INSERT query for which data transfer is needed (not an INSERT SELECT) is processed separately. - const ASTInsertQuery * insert = typeid_cast(&*parsed_query); - connection->forceConnected(); - if (insert && !insert->select) + /// INSERT query for which data transfer is needed (not an INSERT SELECT) is processed separately. + const auto * insert_query = parsed_query->as(); + if (insert_query && !insert_query->select) processInsertQuery(); else processOrdinaryQuery(); @@ -814,7 +811,7 @@ private: /// Do not change context (current DB, settings) in case of an exception. if (!got_exception) { - if (set_query) + if (const auto * set_query = parsed_query->as()) { /// Save all changes in settings to avoid losing them if the connection is lost. for (const auto & change : set_query->changes) @@ -826,7 +823,7 @@ private: } } - if (use_query) + if (const auto * use_query = parsed_query->as()) { const String & new_database = use_query->database; /// If the client initiates the reconnection, it takes the settings from the config. @@ -858,7 +855,7 @@ private: /// Convert external tables to ExternalTableData and send them using the connection. void sendExternalTables() { - auto * select = typeid_cast(&*parsed_query); + const auto * select = parsed_query->as(); if (!select && !external_tables.empty()) throw Exception("External tables could be sent only with select query", ErrorCodes::BAD_ARGUMENTS); @@ -883,7 +880,7 @@ private: void processInsertQuery() { /// Send part of query without data, because data will be sent separately. - const ASTInsertQuery & parsed_insert_query = typeid_cast(*parsed_query); + const auto & parsed_insert_query = parsed_query->as(); String query_without_data = parsed_insert_query.data ? query.substr(0, parsed_insert_query.data - query.data()) : query; @@ -940,7 +937,7 @@ private: void sendData(Block & sample, const ColumnsDescription & columns_description) { /// If INSERT data must be sent. - const ASTInsertQuery * parsed_insert_query = typeid_cast(&*parsed_query); + const auto * parsed_insert_query = parsed_query->as(); if (!parsed_insert_query) return; @@ -965,7 +962,7 @@ private: String current_format = insert_format; /// Data format can be specified in the INSERT query. - if (ASTInsertQuery * insert = typeid_cast(&*parsed_query)) + if (const auto * insert = parsed_query->as()) { if (!insert->format.empty()) current_format = insert->format; @@ -1231,12 +1228,14 @@ private: String current_format = format; /// The query can specify output format or output file. - if (ASTQueryWithOutput * query_with_output = dynamic_cast(&*parsed_query)) + /// FIXME: try to prettify this cast using `as<>()` + if (const auto * query_with_output = dynamic_cast(parsed_query.get())) { - if (query_with_output->out_file != nullptr) + if (query_with_output->out_file) { - const auto & out_file_node = typeid_cast(*query_with_output->out_file); + const auto & out_file_node = query_with_output->out_file->as(); const auto & out_file = out_file_node.value.safeGet(); + out_file_buf.emplace(out_file, DBMS_DEFAULT_BUFFER_SIZE, O_WRONLY | O_EXCL | O_CREAT); out_buf = &*out_file_buf; @@ -1248,7 +1247,7 @@ private: { if (has_vertical_output_suffix) throw Exception("Output format already specified", ErrorCodes::CLIENT_OUTPUT_FORMAT_SPECIFIED); - const auto & id = typeid_cast(*query_with_output->format); + const auto & id = query_with_output->format->as(); current_format = id.name; } if (query_with_output->settings_ast) diff --git a/dbms/programs/copier/ClusterCopier.cpp b/dbms/programs/copier/ClusterCopier.cpp index 588c9984f58..451df591bbd 100644 --- a/dbms/programs/copier/ClusterCopier.cpp +++ b/dbms/programs/copier/ClusterCopier.cpp @@ -483,7 +483,7 @@ String DB::TaskShard::getHostNameExample() const static bool isExtendedDefinitionStorage(const ASTPtr & storage_ast) { - const ASTStorage & storage = typeid_cast(*storage_ast); + const auto & storage = storage_ast->as(); return storage.partition_by || storage.order_by || storage.sample_by; } @@ -491,8 +491,8 @@ static ASTPtr extractPartitionKey(const ASTPtr & storage_ast) { String storage_str = queryToString(storage_ast); - const ASTStorage & storage = typeid_cast(*storage_ast); - const ASTFunction & engine = typeid_cast(*storage.engine); + const auto & storage = storage_ast->as(); + const auto & engine = storage.engine->as(); if (!endsWith(engine.name, "MergeTree")) { @@ -501,7 +501,7 @@ static ASTPtr extractPartitionKey(const ASTPtr & storage_ast) } ASTPtr arguments_ast = engine.arguments->clone(); - ASTs & arguments = typeid_cast(*arguments_ast).children; + ASTs & arguments = arguments_ast->children; if (isExtendedDefinitionStorage(storage_ast)) { @@ -1179,12 +1179,12 @@ protected: /// Removes MATERIALIZED and ALIAS columns from create table query static ASTPtr removeAliasColumnsFromCreateQuery(const ASTPtr & query_ast) { - const ASTs & column_asts = typeid_cast(*query_ast).columns_list->columns->children; + const ASTs & column_asts = query_ast->as().columns_list->columns->children; auto new_columns = std::make_shared(); for (const ASTPtr & column_ast : column_asts) { - const ASTColumnDeclaration & column = typeid_cast(*column_ast); + const auto & column = column_ast->as(); if (!column.default_specifier.empty()) { @@ -1197,12 +1197,11 @@ protected: } ASTPtr new_query_ast = query_ast->clone(); - ASTCreateQuery & new_query = typeid_cast(*new_query_ast); + auto & new_query = new_query_ast->as(); auto new_columns_list = std::make_shared(); new_columns_list->set(new_columns_list->columns, new_columns); - new_columns_list->set( - new_columns_list->indices, typeid_cast(*query_ast).columns_list->indices->clone()); + new_columns_list->set(new_columns_list->indices, query_ast->as()->columns_list->indices->clone()); new_query.replace(new_query.columns_list, new_columns_list); @@ -1212,7 +1211,7 @@ protected: /// Replaces ENGINE and table name in a create query std::shared_ptr rewriteCreateQueryStorage(const ASTPtr & create_query_ast, const DatabaseAndTableName & new_table, const ASTPtr & new_storage_ast) { - ASTCreateQuery & create = typeid_cast(*create_query_ast); + const auto & create = create_query_ast->as(); auto res = std::make_shared(create); if (create.storage == nullptr || new_storage_ast == nullptr) @@ -1646,7 +1645,7 @@ protected: /// Try create table (if not exists) on each shard { auto create_query_push_ast = rewriteCreateQueryStorage(task_shard.current_pull_table_create_query, task_table.table_push, task_table.engine_push_ast); - typeid_cast(*create_query_push_ast).if_not_exists = true; + create_query_push_ast->as().if_not_exists = true; String query = queryToString(create_query_push_ast); LOG_DEBUG(log, "Create destination tables. Query: " << query); @@ -1779,7 +1778,7 @@ protected: void dropAndCreateLocalTable(const ASTPtr & create_ast) { - auto & create = typeid_cast(*create_ast); + const auto & create = create_ast->as(); dropLocalTableIfExists({create.database, create.table}); InterpreterCreateQuery interpreter(create_ast, context); diff --git a/dbms/src/AggregateFunctions/parseAggregateFunctionParameters.cpp b/dbms/src/AggregateFunctions/parseAggregateFunctionParameters.cpp index 5e5738592f7..bcb73f1e9d9 100644 --- a/dbms/src/AggregateFunctions/parseAggregateFunctionParameters.cpp +++ b/dbms/src/AggregateFunctions/parseAggregateFunctionParameters.cpp @@ -15,7 +15,7 @@ namespace ErrorCodes Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const std::string & error_context) { - const ASTs & parameters = typeid_cast(*expression_list).children; + const ASTs & parameters = expression_list->children; if (parameters.empty()) throw Exception("Parameters list to aggregate functions cannot be empty", ErrorCodes::BAD_ARGUMENTS); @@ -23,14 +23,14 @@ Array getAggregateFunctionParametersArray(const ASTPtr & expression_list, const for (size_t i = 0; i < parameters.size(); ++i) { - const ASTLiteral * lit = typeid_cast(parameters[i].get()); - if (!lit) + const auto * literal = parameters[i]->as(); + if (!literal) { throw Exception("Parameters to aggregate functions must be literals" + (error_context.empty() ? "" : " (in " + error_context +")"), ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS); } - params_row[i] = lit->value; + params_row[i] = literal->value; } return params_row; @@ -67,8 +67,7 @@ void getAggregateFunctionNameAndParametersArray( parameters_str.data(), parameters_str.data() + parameters_str.size(), "parameters of aggregate function in " + error_context, 0); - ASTExpressionList & args_list = typeid_cast(*args_ast); - if (args_list.children.empty()) + if (args_ast->children.empty()) throw Exception("Incorrect list of parameters to aggregate function " + aggregate_function_name, ErrorCodes::BAD_ARGUMENTS); diff --git a/dbms/src/Common/TypePromotion.h b/dbms/src/Common/TypePromotion.h new file mode 100644 index 00000000000..18ac0821b2c --- /dev/null +++ b/dbms/src/Common/TypePromotion.h @@ -0,0 +1,63 @@ +#pragma once + +#include + +namespace DB +{ + +/* This base class adds public methods: + * - Derived * as() + * - const Derived * as() const + * - Derived & as() + * - const Derived & as() const + */ + +template +class TypePromotion +{ +private: + /// Need a helper-struct to fight the lack of the function-template partial specialization. + template > + struct CastHelper; + + template + struct CastHelper + { + auto & value(Base * ptr) { return typeid_cast(*ptr); } + }; + + template + struct CastHelper + { + auto & value(const Base * ptr) { return typeid_cast>>>(*ptr); } + }; + + template + struct CastHelper + { + auto * value(Base * ptr) { return typeid_cast(ptr); } + }; + + template + struct CastHelper + { + auto * value(const Base * ptr) { return typeid_cast *>(ptr); } + }; + +public: + template + auto as() -> std::invoke_result_t::value), CastHelper, Base *> + { + // TODO: if we do downcast to base type, then just return |this|. + return CastHelper().value(static_cast(this)); + } + + template + auto as() const -> std::invoke_result_t::value), CastHelper, const Base *> + { + // TODO: if we do downcast to base type, then just return |this|. + return CastHelper().value(static_cast(this)); + } +}; + +} // namespace DB diff --git a/dbms/src/Common/typeid_cast.h b/dbms/src/Common/typeid_cast.h index 99faeb40742..9285355e788 100644 --- a/dbms/src/Common/typeid_cast.h +++ b/dbms/src/Common/typeid_cast.h @@ -25,18 +25,32 @@ namespace DB template std::enable_if_t, To> typeid_cast(From & from) { - if (typeid(from) == typeid(To)) - return static_cast(from); - else - throw DB::Exception("Bad cast from type " + demangle(typeid(from).name()) + " to " + demangle(typeid(To).name()), - DB::ErrorCodes::BAD_CAST); + try + { + if (typeid(from) == typeid(To)) + return static_cast(from); + } + catch (const std::exception & e) + { + throw DB::Exception(e.what(), DB::ErrorCodes::BAD_CAST); + } + + throw DB::Exception("Bad cast from type " + demangle(typeid(from).name()) + " to " + demangle(typeid(To).name()), + DB::ErrorCodes::BAD_CAST); } template To typeid_cast(From * from) { - if (typeid(*from) == typeid(std::remove_pointer_t)) - return static_cast(from); - else - return nullptr; + try + { + if (typeid(*from) == typeid(std::remove_pointer_t)) + return static_cast(from); + else + return nullptr; + } + catch (const std::exception & e) + { + throw DB::Exception(e.what(), DB::ErrorCodes::BAD_CAST); + } } diff --git a/dbms/src/Compression/CompressionCodecDelta.cpp b/dbms/src/Compression/CompressionCodecDelta.cpp index 2c9eba1c558..08cc37864dd 100644 --- a/dbms/src/Compression/CompressionCodecDelta.cpp +++ b/dbms/src/Compression/CompressionCodecDelta.cpp @@ -144,7 +144,7 @@ void registerCodecDelta(CompressionCodecFactory & factory) throw Exception("Delta codec must have 1 parameter, given " + std::to_string(arguments->children.size()), ErrorCodes::ILLEGAL_SYNTAX_FOR_CODEC_TYPE); const auto children = arguments->children; - const ASTLiteral * literal = static_cast(children[0].get()); + const auto * literal = children[0]->as(); size_t user_bytes_size = literal->value.safeGet(); if (user_bytes_size != 1 && user_bytes_size != 2 && user_bytes_size != 4 && user_bytes_size != 8) throw Exception("Delta value for delta codec can be 1, 2, 4 or 8, given " + toString(user_bytes_size), ErrorCodes::ILLEGAL_CODEC_PARAMETER); diff --git a/dbms/src/Compression/CompressionCodecLZ4.cpp b/dbms/src/Compression/CompressionCodecLZ4.cpp index c0d0fa99e1b..08553e0920c 100644 --- a/dbms/src/Compression/CompressionCodecLZ4.cpp +++ b/dbms/src/Compression/CompressionCodecLZ4.cpp @@ -86,7 +86,7 @@ void registerCodecLZ4HC(CompressionCodecFactory & factory) throw Exception("LZ4HC codec must have 1 parameter, given " + std::to_string(arguments->children.size()), ErrorCodes::ILLEGAL_SYNTAX_FOR_CODEC_TYPE); const auto children = arguments->children; - const ASTLiteral * literal = static_cast(children[0].get()); + const auto * literal = children[0]->as(); level = literal->value.safeGet(); } @@ -100,4 +100,3 @@ CompressionCodecLZ4HC::CompressionCodecLZ4HC(int level_) } } - diff --git a/dbms/src/Compression/CompressionCodecZSTD.cpp b/dbms/src/Compression/CompressionCodecZSTD.cpp index ac7e24ff9ba..9e7e66a5e6e 100644 --- a/dbms/src/Compression/CompressionCodecZSTD.cpp +++ b/dbms/src/Compression/CompressionCodecZSTD.cpp @@ -73,7 +73,7 @@ void registerCodecZSTD(CompressionCodecFactory & factory) throw Exception("ZSTD codec must have 1 parameter, given " + std::to_string(arguments->children.size()), ErrorCodes::ILLEGAL_SYNTAX_FOR_CODEC_TYPE); const auto children = arguments->children; - const ASTLiteral * literal = static_cast(children[0].get()); + const auto * literal = children[0]->as(); level = literal->value.safeGet(); if (level > ZSTD_maxCLevel()) throw Exception("ZSTD codec can't have level more that " + toString(ZSTD_maxCLevel()) + ", given " + toString(level), ErrorCodes::ILLEGAL_CODEC_PARAMETER); diff --git a/dbms/src/Compression/CompressionFactory.cpp b/dbms/src/Compression/CompressionFactory.cpp index b5b2bfe6b5e..ed34b8817d3 100644 --- a/dbms/src/Compression/CompressionFactory.cpp +++ b/dbms/src/Compression/CompressionFactory.cpp @@ -56,15 +56,15 @@ CompressionCodecPtr CompressionCodecFactory::get(const std::vector(ast.get())) + if (const auto * func = ast->as()) { Codecs codecs; codecs.reserve(func->arguments->children.size()); for (const auto & inner_codec_ast : func->arguments->children) { - if (const auto * family_name = typeid_cast(inner_codec_ast.get())) + if (const auto * family_name = inner_codec_ast->as()) codecs.emplace_back(getImpl(family_name->name, {}, column_type)); - else if (const auto * ast_func = typeid_cast(inner_codec_ast.get())) + else if (const auto * ast_func = inner_codec_ast->as()) codecs.emplace_back(getImpl(ast_func->name, ast_func->arguments, column_type)); else throw Exception("Unexpected AST element for compression codec", ErrorCodes::UNEXPECTED_AST_STRUCTURE); diff --git a/dbms/src/Compression/CompressionFactory.h b/dbms/src/Compression/CompressionFactory.h index db28a192011..b36bed1cf8e 100644 --- a/dbms/src/Compression/CompressionFactory.h +++ b/dbms/src/Compression/CompressionFactory.h @@ -1,14 +1,17 @@ #pragma once -#include +#include +#include +#include +#include +#include + +#include + #include +#include #include #include -#include -#include -#include -#include -#include namespace DB { @@ -19,10 +22,6 @@ using CompressionCodecPtr = std::shared_ptr; using CodecNameWithLevel = std::pair>; -class IAST; - -using ASTPtr = std::shared_ptr; - /** Creates a codec object by name of compression algorithm family and parameters. */ class CompressionCodecFactory final : public ext::singleton diff --git a/dbms/src/DataStreams/InputStreamFromASTInsertQuery.cpp b/dbms/src/DataStreams/InputStreamFromASTInsertQuery.cpp index f9da736d4c4..dc039648053 100644 --- a/dbms/src/DataStreams/InputStreamFromASTInsertQuery.cpp +++ b/dbms/src/DataStreams/InputStreamFromASTInsertQuery.cpp @@ -20,7 +20,7 @@ namespace ErrorCodes InputStreamFromASTInsertQuery::InputStreamFromASTInsertQuery( const ASTPtr & ast, ReadBuffer * input_buffer_tail_part, const Block & header, Context & context) { - const ASTInsertQuery * ast_insert_query = dynamic_cast(ast.get()); + const auto * ast_insert_query = ast->as(); if (!ast_insert_query) throw Exception("Logical error: query requires data to insert, but it is not INSERT query", ErrorCodes::LOGICAL_ERROR); diff --git a/dbms/src/DataTypes/DataTypeAggregateFunction.cpp b/dbms/src/DataTypes/DataTypeAggregateFunction.cpp index 25ed8836538..a5dd5f8be62 100644 --- a/dbms/src/DataTypes/DataTypeAggregateFunction.cpp +++ b/dbms/src/DataTypes/DataTypeAggregateFunction.cpp @@ -340,30 +340,30 @@ static DataTypePtr create(const ASTPtr & arguments) throw Exception("Data type AggregateFunction requires parameters: " "name of aggregate function and list of data types for arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - if (const ASTFunction * parametric = typeid_cast(arguments->children[0].get())) + if (const auto * parametric = arguments->children[0]->as()) { if (parametric->parameters) throw Exception("Unexpected level of parameters to aggregate function", ErrorCodes::SYNTAX_ERROR); function_name = parametric->name; - const ASTs & parameters = typeid_cast(*parametric->arguments).children; + const ASTs & parameters = parametric->arguments->children; params_row.resize(parameters.size()); for (size_t i = 0; i < parameters.size(); ++i) { - const ASTLiteral * lit = typeid_cast(parameters[i].get()); - if (!lit) + const auto * literal = parameters[i]->as(); + if (!literal) throw Exception("Parameters to aggregate functions must be literals", ErrorCodes::PARAMETERS_TO_AGGREGATE_FUNCTIONS_MUST_BE_LITERALS); - params_row[i] = lit->value; + params_row[i] = literal->value; } } else if (auto opt_name = getIdentifierName(arguments->children[0])) { function_name = *opt_name; } - else if (typeid_cast(arguments->children[0].get())) + else if (arguments->children[0]->as()) { throw Exception("Aggregate function name for data type AggregateFunction must be passed as identifier (without quotes) or function", ErrorCodes::BAD_ARGUMENTS); @@ -389,4 +389,3 @@ void registerDataTypeAggregateFunction(DataTypeFactory & factory) } - diff --git a/dbms/src/DataTypes/DataTypeDateTime.cpp b/dbms/src/DataTypes/DataTypeDateTime.cpp index 25b2d966e6b..f3d6efa1488 100644 --- a/dbms/src/DataTypes/DataTypeDateTime.cpp +++ b/dbms/src/DataTypes/DataTypeDateTime.cpp @@ -186,7 +186,7 @@ static DataTypePtr create(const ASTPtr & arguments) if (arguments->children.size() != 1) throw Exception("DateTime data type can optionally have only one argument - time zone name", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - const ASTLiteral * arg = typeid_cast(arguments->children[0].get()); + const auto * arg = arguments->children[0]->as(); if (!arg || arg->value.getType() != Field::Types::String) throw Exception("Parameter for DateTime data type must be string literal", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/dbms/src/DataTypes/DataTypeEnum.cpp b/dbms/src/DataTypes/DataTypeEnum.cpp index bd93105a288..24f760a1800 100644 --- a/dbms/src/DataTypes/DataTypeEnum.cpp +++ b/dbms/src/DataTypes/DataTypeEnum.cpp @@ -357,7 +357,7 @@ static DataTypePtr create(const ASTPtr & arguments) /// Children must be functions 'equals' with string literal as left argument and numeric literal as right argument. for (const ASTPtr & child : arguments->children) { - const ASTFunction * func = typeid_cast(child.get()); + const auto * func = child->as(); if (!func || func->name != "equals" || func->parameters @@ -366,8 +366,8 @@ static DataTypePtr create(const ASTPtr & arguments) throw Exception("Elements of Enum data type must be of form: 'name' = number, where name is string literal and number is an integer", ErrorCodes::UNEXPECTED_AST_STRUCTURE); - const ASTLiteral * name_literal = typeid_cast(func->arguments->children[0].get()); - const ASTLiteral * value_literal = typeid_cast(func->arguments->children[1].get()); + const auto * name_literal = func->arguments->children[0]->as(); + const auto * value_literal = func->arguments->children[1]->as(); if (!name_literal || !value_literal diff --git a/dbms/src/DataTypes/DataTypeFactory.cpp b/dbms/src/DataTypes/DataTypeFactory.cpp index 85df6bee260..a0afab890e9 100644 --- a/dbms/src/DataTypes/DataTypeFactory.cpp +++ b/dbms/src/DataTypes/DataTypeFactory.cpp @@ -32,19 +32,19 @@ DataTypePtr DataTypeFactory::get(const String & full_name) const DataTypePtr DataTypeFactory::get(const ASTPtr & ast) const { - if (const ASTFunction * func = typeid_cast(ast.get())) + if (const auto * func = ast->as()) { if (func->parameters) throw Exception("Data type cannot have multiple parenthesed parameters.", ErrorCodes::ILLEGAL_SYNTAX_FOR_DATA_TYPE); return get(func->name, func->arguments); } - if (const ASTIdentifier * ident = typeid_cast(ast.get())) + if (const auto * ident = ast->as()) { return get(ident->name, {}); } - if (const ASTLiteral * lit = typeid_cast(ast.get())) + if (const auto * lit = ast->as()) { if (lit->value.isNull()) return get("Null", {}); diff --git a/dbms/src/DataTypes/DataTypeFactory.h b/dbms/src/DataTypes/DataTypeFactory.h index 95cda9002f4..c6ef100bbb7 100644 --- a/dbms/src/DataTypes/DataTypeFactory.h +++ b/dbms/src/DataTypes/DataTypeFactory.h @@ -1,12 +1,15 @@ #pragma once -#include -#include -#include -#include #include +#include +#include + #include +#include +#include +#include + namespace DB { @@ -17,9 +20,6 @@ using DataTypePtr = std::shared_ptr; class IDataTypeDomain; using DataTypeDomainPtr = std::unique_ptr; -class IAST; -using ASTPtr = std::shared_ptr; - /** Creates a data type by name of data type family and parameters. */ diff --git a/dbms/src/DataTypes/DataTypeFixedString.cpp b/dbms/src/DataTypes/DataTypeFixedString.cpp index 64c94602f8c..d1a007e16d2 100644 --- a/dbms/src/DataTypes/DataTypeFixedString.cpp +++ b/dbms/src/DataTypes/DataTypeFixedString.cpp @@ -273,7 +273,7 @@ static DataTypePtr create(const ASTPtr & arguments) if (!arguments || arguments->children.size() != 1) throw Exception("FixedString data type family must have exactly one argument - size in bytes", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - const ASTLiteral * argument = typeid_cast(arguments->children[0].get()); + const auto * argument = arguments->children[0]->as(); if (!argument || argument->value.getType() != Field::Types::UInt64 || argument->value.get() == 0) throw Exception("FixedString data type family must have a number (positive integer) as its argument", ErrorCodes::UNEXPECTED_AST_STRUCTURE); diff --git a/dbms/src/DataTypes/DataTypeTuple.cpp b/dbms/src/DataTypes/DataTypeTuple.cpp index a799662429f..8f52b5fd3ff 100644 --- a/dbms/src/DataTypes/DataTypeTuple.cpp +++ b/dbms/src/DataTypes/DataTypeTuple.cpp @@ -531,7 +531,7 @@ static DataTypePtr create(const ASTPtr & arguments) for (const ASTPtr & child : arguments->children) { - if (const ASTNameTypePair * name_and_type_pair = typeid_cast(child.get())) + if (const auto * name_and_type_pair = child->as()) { nested_types.emplace_back(DataTypeFactory::instance().get(name_and_type_pair->type)); names.emplace_back(name_and_type_pair->name); diff --git a/dbms/src/DataTypes/DataTypesDecimal.cpp b/dbms/src/DataTypes/DataTypesDecimal.cpp index a89196b0f0a..8ec5bb6664f 100644 --- a/dbms/src/DataTypes/DataTypesDecimal.cpp +++ b/dbms/src/DataTypes/DataTypesDecimal.cpp @@ -208,8 +208,8 @@ static DataTypePtr create(const ASTPtr & arguments) throw Exception("Decimal data type family must have exactly two arguments: precision and scale", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - const ASTLiteral * precision = typeid_cast(arguments->children[0].get()); - const ASTLiteral * scale = typeid_cast(arguments->children[1].get()); + const auto * precision = arguments->children[0]->as(); + const auto * scale = arguments->children[1]->as(); if (!precision || precision->value.getType() != Field::Types::UInt64 || !scale || !(scale->value.getType() == Field::Types::Int64 || scale->value.getType() == Field::Types::UInt64)) @@ -228,7 +228,7 @@ static DataTypePtr createExect(const ASTPtr & arguments) throw Exception("Decimal data type family must have exactly two arguments: precision and scale", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - const ASTLiteral * scale_arg = typeid_cast(arguments->children[0].get()); + const auto * scale_arg = arguments->children[0]->as(); if (!scale_arg || !(scale_arg->value.getType() == Field::Types::Int64 || scale_arg->value.getType() == Field::Types::UInt64)) throw Exception("Decimal data type family must have a two numbers as its arguments", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/dbms/src/Databases/DatabaseOrdinary.cpp b/dbms/src/Databases/DatabaseOrdinary.cpp index 78926146169..e144ed9071e 100644 --- a/dbms/src/Databases/DatabaseOrdinary.cpp +++ b/dbms/src/Databases/DatabaseOrdinary.cpp @@ -370,7 +370,7 @@ static ASTPtr getCreateQueryFromMetadata(const String & metadata_path, const Str if (ast) { - ASTCreateQuery & ast_create_query = typeid_cast(*ast); + auto & ast_create_query = ast->as(); ast_create_query.attach = false; ast_create_query.database = database; } @@ -415,8 +415,7 @@ void DatabaseOrdinary::renameTable( ASTPtr ast = getQueryFromMetadata(detail::getTableMetadataPath(metadata_path, table_name)); if (!ast) throw Exception("There is no metadata file for table " + table_name, ErrorCodes::FILE_DOESNT_EXIST); - ASTCreateQuery & ast_create_query = typeid_cast(*ast); - ast_create_query.table = to_table_name; + ast->as().table = to_table_name; /// NOTE Non-atomic. to_database_concrete->createTable(context, to_table_name, table, ast); @@ -534,7 +533,7 @@ void DatabaseOrdinary::alterTable( ParserCreateQuery parser; ASTPtr ast = parseQuery(parser, statement.data(), statement.data() + statement.size(), "in file " + table_metadata_path, 0); - ASTCreateQuery & ast_create_query = typeid_cast(*ast); + const auto & ast_create_query = ast->as(); ASTPtr new_columns = InterpreterCreateQuery::formatColumns(columns); ASTPtr new_indices = InterpreterCreateQuery::formatIndices(indices); diff --git a/dbms/src/Databases/DatabasesCommon.cpp b/dbms/src/Databases/DatabasesCommon.cpp index 006d65ede7b..6292c4f4149 100644 --- a/dbms/src/Databases/DatabasesCommon.cpp +++ b/dbms/src/Databases/DatabasesCommon.cpp @@ -26,7 +26,7 @@ namespace ErrorCodes String getTableDefinitionFromCreateQuery(const ASTPtr & query) { ASTPtr query_clone = query->clone(); - ASTCreateQuery & create = typeid_cast(*query_clone.get()); + auto & create = query_clone->as(); /// We remove everything that is not needed for ATTACH from the query. create.attach = true; @@ -62,7 +62,7 @@ std::pair createTableFromDefinition( ParserCreateQuery parser; ASTPtr ast = parseQuery(parser, definition.data(), definition.data() + definition.size(), description_for_error_message, 0); - ASTCreateQuery & ast_create_query = typeid_cast(*ast); + auto & ast_create_query = ast->as(); ast_create_query.attach = true; ast_create_query.database = database_name; diff --git a/dbms/src/Databases/IDatabase.h b/dbms/src/Databases/IDatabase.h index e6b67a87c9b..7c0d501fc60 100644 --- a/dbms/src/Databases/IDatabase.h +++ b/dbms/src/Databases/IDatabase.h @@ -1,16 +1,18 @@ #pragma once -#include #include +#include +#include +#include #include #include -#include -#include -#include #include -#include #include -#include +#include + +#include +#include +#include namespace DB @@ -21,9 +23,6 @@ class Context; class IStorage; using StoragePtr = std::shared_ptr; -class IAST; -using ASTPtr = std::shared_ptr; - struct Settings; @@ -157,4 +156,3 @@ using DatabasePtr = std::shared_ptr; using Databases = std::map; } - diff --git a/dbms/src/Interpreters/ActionsVisitor.cpp b/dbms/src/Interpreters/ActionsVisitor.cpp index 0dba4f6a163..5191e86c57e 100644 --- a/dbms/src/Interpreters/ActionsVisitor.cpp +++ b/dbms/src/Interpreters/ActionsVisitor.cpp @@ -84,11 +84,11 @@ SetPtr makeExplicitSet( auto getTupleTypeFromAst = [&context](const ASTPtr & tuple_ast) -> DataTypePtr { - auto ast_function = typeid_cast(tuple_ast.get()); - if (ast_function && ast_function->name == "tuple" && !ast_function->arguments->children.empty()) + const auto * func = tuple_ast->as(); + if (func && func->name == "tuple" && !func->arguments->children.empty()) { /// Won't parse all values of outer tuple. - auto element = ast_function->arguments->children.at(0); + auto element = func->arguments->children.at(0); std::pair value_raw = evaluateConstantExpression(element, context); return std::make_shared(DataTypes({value_raw.second})); } @@ -122,7 +122,7 @@ SetPtr makeExplicitSet( /// 1 in (1, 2); (1, 2) in ((1, 2), (3, 4)); etc. else if (left_tuple_depth + 1 == right_tuple_depth) { - ASTFunction * set_func = typeid_cast(right_arg.get()); + const auto * set_func = right_arg->as(); if (!set_func || set_func->name != "tuple") throw Exception("Incorrect type of 2nd argument for function " + node->name @@ -263,11 +263,10 @@ void ActionsVisitor::visit(const ASTPtr & ast) }; /// If the result of the calculation already exists in the block. - if ((typeid_cast(ast.get()) || typeid_cast(ast.get())) - && actions_stack.getSampleBlock().has(getColumnName())) + if ((ast->as() || ast->as()) && actions_stack.getSampleBlock().has(getColumnName())) return; - if (auto * identifier = typeid_cast(ast.get())) + if (const auto * identifier = ast->as()) { if (!only_consts && !actions_stack.getSampleBlock().has(getColumnName())) { @@ -288,7 +287,7 @@ void ActionsVisitor::visit(const ASTPtr & ast) actions_stack.addAction(ExpressionAction::addAliases({{identifier->name, identifier->alias}})); } } - else if (ASTFunction * node = typeid_cast(ast.get())) + else if (const auto * node = ast->as()) { if (node->name == "lambda") throw Exception("Unexpected lambda expression", ErrorCodes::UNEXPECTED_EXPRESSION); @@ -383,14 +382,14 @@ void ActionsVisitor::visit(const ASTPtr & ast) auto & child = node->arguments->children[arg]; auto child_column_name = child->getColumnName(); - ASTFunction * lambda = typeid_cast(child.get()); + const auto * lambda = child->as(); if (lambda && lambda->name == "lambda") { /// If the argument is a lambda expression, just remember its approximate type. if (lambda->arguments->children.size() != 2) throw Exception("lambda requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - ASTFunction * lambda_args_tuple = typeid_cast(lambda->arguments->children.at(0).get()); + const auto * lambda_args_tuple = lambda->arguments->children.at(0)->as(); if (!lambda_args_tuple || lambda_args_tuple->name != "tuple") throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH); @@ -454,12 +453,12 @@ void ActionsVisitor::visit(const ASTPtr & ast) { ASTPtr child = node->arguments->children[i]; - ASTFunction * lambda = typeid_cast(child.get()); + const auto * lambda = child->as(); if (lambda && lambda->name == "lambda") { const DataTypeFunction * lambda_type = typeid_cast(argument_types[i].get()); - ASTFunction * lambda_args_tuple = typeid_cast(lambda->arguments->children.at(0).get()); - ASTs lambda_arg_asts = lambda_args_tuple->arguments->children; + const auto * lambda_args_tuple = lambda->arguments->children.at(0)->as(); + const ASTs & lambda_arg_asts = lambda_args_tuple->arguments->children; NamesAndTypesList lambda_arguments; for (size_t j = 0; j < lambda_arg_asts.size(); ++j) @@ -517,7 +516,7 @@ void ActionsVisitor::visit(const ASTPtr & ast) ExpressionAction::applyFunction(function_builder, argument_names, getColumnName())); } } - else if (ASTLiteral * literal = typeid_cast(ast.get())) + else if (const auto * literal = ast->as()) { DataTypePtr type = applyVisitor(FieldToDataType(), literal->value); @@ -533,8 +532,7 @@ void ActionsVisitor::visit(const ASTPtr & ast) for (auto & child : ast->children) { /// Do not go to FROM, JOIN, UNION. - if (!typeid_cast(child.get()) - && !typeid_cast(child.get())) + if (!child->as() && !child->as()) visit(child); } } @@ -550,8 +548,8 @@ SetPtr ActionsVisitor::makeSet(const ASTFunction * node, const Block & sample_bl const ASTPtr & arg = args.children.at(1); /// If the subquery or table name for SELECT. - const ASTIdentifier * identifier = typeid_cast(arg.get()); - if (typeid_cast(arg.get()) || identifier) + const auto * identifier = arg->as(); + if (arg->as() || identifier) { auto set_key = PreparedSetKey::forSubquery(*arg); if (prepared_sets.count(set_key)) diff --git a/dbms/src/Interpreters/Aliases.h b/dbms/src/Interpreters/Aliases.h index b51dc80ae5f..52159442224 100644 --- a/dbms/src/Interpreters/Aliases.h +++ b/dbms/src/Interpreters/Aliases.h @@ -1,15 +1,13 @@ #pragma once -#include -#include #include +#include + +#include namespace DB { -class IAST; -using ASTPtr = std::shared_ptr; - using Aliases = std::unordered_map; } diff --git a/dbms/src/Interpreters/AnalyzedJoin.cpp b/dbms/src/Interpreters/AnalyzedJoin.cpp index ce44430a5c3..333e82690dc 100644 --- a/dbms/src/Interpreters/AnalyzedJoin.cpp +++ b/dbms/src/Interpreters/AnalyzedJoin.cpp @@ -45,7 +45,7 @@ ExpressionActionsPtr AnalyzedJoin::createJoinedBlockActions( if (!join) return nullptr; - const auto & join_params = static_cast(*join->table_join); + const auto & join_params = join->table_join->as(); /// Create custom expression list with join keys from right table. auto expression_list = std::make_shared(); diff --git a/dbms/src/Interpreters/ArrayJoinedColumnsVisitor.h b/dbms/src/Interpreters/ArrayJoinedColumnsVisitor.h index fc603ea3131..204e4324c53 100644 --- a/dbms/src/Interpreters/ArrayJoinedColumnsVisitor.h +++ b/dbms/src/Interpreters/ArrayJoinedColumnsVisitor.h @@ -40,11 +40,10 @@ public: static bool needChildVisit(ASTPtr & node, const ASTPtr & child) { - if (typeid_cast(node.get())) + if (node->as()) return false; - if (typeid_cast(child.get()) || - typeid_cast(child.get())) + if (child->as() || child->as()) return false; return true; @@ -52,9 +51,9 @@ public: static void visit(ASTPtr & ast, Data & data) { - if (auto * t = typeid_cast(ast.get())) + if (const auto * t = ast->as()) visit(*t, ast, data); - if (auto * t = typeid_cast(ast.get())) + if (const auto * t = ast->as()) visit(*t, ast, data); } @@ -73,7 +72,7 @@ private: const String nested_table_name = ast->getColumnName(); const String nested_table_alias = ast->getAliasOrColumnName(); - if (nested_table_alias == nested_table_name && !isIdentifier(ast)) + if (nested_table_alias == nested_table_name && !ast->as()) throw Exception("No alias for non-trivial value in ARRAY JOIN: " + nested_table_name, ErrorCodes::ALIAS_REQUIRED); if (data.array_join_alias_to_name.count(nested_table_alias) || data.aliases.count(nested_table_alias)) diff --git a/dbms/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp b/dbms/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp index b34b3d34633..0a1243758cf 100644 --- a/dbms/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp +++ b/dbms/src/Interpreters/ClusterProxy/SelectStreamFactory.cpp @@ -98,7 +98,7 @@ void SelectStreamFactory::createForShard( if (table_func_ptr) { - auto table_function = static_cast(table_func_ptr.get()); + const auto * table_function = table_func_ptr->as(); main_table_storage = TableFunctionFactory::instance().get(table_function->name, context)->execute(table_func_ptr, context); } else diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index 2c1a9a58567..42ef341a2f8 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -892,8 +892,7 @@ StoragePtr Context::executeTableFunction(const ASTPtr & table_expression) if (!res) { - TableFunctionPtr table_function_ptr = TableFunctionFactory::instance().get( - typeid_cast(table_expression.get())->name, *this); + TableFunctionPtr table_function_ptr = TableFunctionFactory::instance().get(table_expression->as()->name, *this); /// Run it and remember the result res = table_function_ptr->execute(table_expression, *this); diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 8b4e66094ff..4f90a50e349 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -1,23 +1,24 @@ #pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include #include #include #include #include #include -#include -#include #include - -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include namespace Poco @@ -68,8 +69,6 @@ class IStorage; class ITableFunction; using StoragePtr = std::shared_ptr; using Tables = std::map; -class IAST; -using ASTPtr = std::shared_ptr; class IBlockInputStream; class IBlockOutputStream; using BlockInputStreamPtr = std::shared_ptr; diff --git a/dbms/src/Interpreters/CrossToInnerJoinVisitor.cpp b/dbms/src/Interpreters/CrossToInnerJoinVisitor.cpp index ede49b90538..dc46c95ace9 100644 --- a/dbms/src/Interpreters/CrossToInnerJoinVisitor.cpp +++ b/dbms/src/Interpreters/CrossToInnerJoinVisitor.cpp @@ -36,13 +36,13 @@ struct JoinedTable JoinedTable(ASTPtr table_element) { - element = typeid_cast(table_element.get()); + element = table_element->as(); if (!element) throw Exception("Logical error: TablesInSelectQueryElement expected", ErrorCodes::LOGICAL_ERROR); if (element->table_join) { - join = typeid_cast(element->table_join.get()); + join = element->table_join->as(); if (join->kind == ASTTableJoin::Kind::Cross || join->kind == ASTTableJoin::Kind::Comma) { @@ -56,7 +56,7 @@ struct JoinedTable if (element->table_expression) { - auto & expr = typeid_cast(*element->table_expression); + const auto & expr = element->table_expression->as(); table = DatabaseAndTableWithAlias(expr); } @@ -105,7 +105,7 @@ public: for (auto & child : node.arguments->children) { - if (auto func = typeid_cast(child.get())) + if (const auto * func = child->as()) visit(*func, child); else ands_only = false; @@ -160,8 +160,8 @@ private: if (node.arguments->children.size() != 2) return false; - auto left = typeid_cast(node.arguments->children[0].get()); - auto right = typeid_cast(node.arguments->children[1].get()); + const auto * left = node.arguments->children[0]->as(); + const auto * right = node.arguments->children[1]->as(); if (!left || !right) return false; @@ -213,7 +213,7 @@ bool getTables(ASTSelectQuery & select, std::vector & joined_tables if (!select.tables) return false; - auto tables = typeid_cast(select.tables.get()); + const auto * tables = select.tables->as(); if (!tables) return false; @@ -232,7 +232,7 @@ bool getTables(ASTSelectQuery & select, std::vector & joined_tables if (num_tables > 2 && t.has_using) throw Exception("Multiple CROSS/COMMA JOIN do not support USING", ErrorCodes::NOT_IMPLEMENTED); - if (ASTTableJoin * join = t.join) + if (auto * join = t.join) if (join->kind == ASTTableJoin::Kind::Comma) ++num_comma; } @@ -244,7 +244,7 @@ bool getTables(ASTSelectQuery & select, std::vector & joined_tables void CrossToInnerJoinMatcher::visit(ASTPtr & ast, Data & data) { - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) visit(*t, ast, data); } diff --git a/dbms/src/Interpreters/DDLWorker.cpp b/dbms/src/Interpreters/DDLWorker.cpp index 35639bf213b..7300adb43cf 100644 --- a/dbms/src/Interpreters/DDLWorker.cpp +++ b/dbms/src/Interpreters/DDLWorker.cpp @@ -449,6 +449,7 @@ void DDLWorker::parseQueryAndResolveHost(DDLTask & task) task.query = parseQuery(parser_query, begin, end, description, 0); } + // XXX: serious design flaw since `ASTQueryWithOnCluster` is not inherited from `IAST`! if (!task.query || !(task.query_on_cluster = dynamic_cast(task.query.get()))) throw Exception("Received unknown DDL query", ErrorCodes::UNKNOWN_TYPE_OF_QUERY); @@ -612,7 +613,7 @@ void DDLWorker::processTask(DDLTask & task, const ZooKeeperPtr & zookeeper) String rewritten_query = queryToString(rewritten_ast); LOG_DEBUG(log, "Executing query: " << rewritten_query); - if (auto ast_alter = dynamic_cast(rewritten_ast.get())) + if (const auto * ast_alter = rewritten_ast->as()) { processTaskAlter(task, ast_alter, rewritten_query, task.entry_path, zookeeper); } @@ -1211,7 +1212,8 @@ BlockIO executeDDLQueryOnCluster(const ASTPtr & query_ptr_, const Context & cont ASTPtr query_ptr = query_ptr_->clone(); ASTQueryWithOutput::resetOutputASTIfExist(*query_ptr); - auto query = dynamic_cast(query_ptr.get()); + // XXX: serious design flaw since `ASTQueryWithOnCluster` is not inherited from `IAST`! + auto * query = dynamic_cast(query_ptr.get()); if (!query) { throw Exception("Distributed execution is not supported for such DDL queries", ErrorCodes::NOT_IMPLEMENTED); @@ -1220,7 +1222,7 @@ BlockIO executeDDLQueryOnCluster(const ASTPtr & query_ptr_, const Context & cont if (!context.getSettingsRef().allow_distributed_ddl) throw Exception("Distributed DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED); - if (auto query_alter = dynamic_cast(query_ptr.get())) + if (const auto * query_alter = query_ptr->as()) { for (const auto & command : query_alter->command_list->commands) { diff --git a/dbms/src/Interpreters/DatabaseAndTableWithAlias.cpp b/dbms/src/Interpreters/DatabaseAndTableWithAlias.cpp index 52b05fc5933..47b0ead07cc 100644 --- a/dbms/src/Interpreters/DatabaseAndTableWithAlias.cpp +++ b/dbms/src/Interpreters/DatabaseAndTableWithAlias.cpp @@ -27,7 +27,7 @@ DatabaseAndTableWithAlias::DatabaseAndTableWithAlias(const ASTIdentifier & ident DatabaseAndTableWithAlias::DatabaseAndTableWithAlias(const ASTPtr & node, const String & current_database) { - const auto * identifier = typeid_cast(node.get()); + const auto * identifier = node->as(); if (!identifier) throw Exception("Logical error: identifier expected", ErrorCodes::LOGICAL_ERROR); @@ -78,10 +78,10 @@ std::vector getSelectTablesExpression(const ASTSelec for (const auto & child : select_query.tables->children) { - ASTTablesInSelectQueryElement * tables_element = static_cast(child.get()); + const auto * tables_element = child->as(); if (tables_element->table_expression) - tables_expression.emplace_back(static_cast(tables_element->table_expression.get())); + tables_expression.emplace_back(tables_element->table_expression->as()); } return tables_expression; @@ -92,17 +92,16 @@ static const ASTTableExpression * getTableExpression(const ASTSelectQuery & sele if (!select.tables) return {}; - ASTTablesInSelectQuery & tables_in_select_query = static_cast(*select.tables); + const auto & tables_in_select_query = select.tables->as(); if (tables_in_select_query.children.size() <= table_number) return {}; - ASTTablesInSelectQueryElement & tables_element = - static_cast(*tables_in_select_query.children[table_number]); + const auto & tables_element = tables_in_select_query.children[table_number]->as(); if (!tables_element.table_expression) return {}; - return static_cast(tables_element.table_expression.get()); + return tables_element.table_expression->as(); } std::vector getDatabaseAndTables(const ASTSelectQuery & select_query, const String & current_database) @@ -125,7 +124,7 @@ std::optional getDatabaseAndTable(const ASTSelectQuer return {}; ASTPtr database_and_table_name = table_expression->database_and_table_name; - if (!database_and_table_name || !isIdentifier(database_and_table_name)) + if (!database_and_table_name || !database_and_table_name->as()) return {}; return DatabaseAndTableWithAlias(database_and_table_name); @@ -142,7 +141,7 @@ ASTPtr extractTableExpression(const ASTSelectQuery & select, size_t table_number return table_expression->table_function; if (table_expression->subquery) - return static_cast(table_expression->subquery.get())->children[0]; + return table_expression->subquery->children[0]; } return nullptr; diff --git a/dbms/src/Interpreters/DatabaseAndTableWithAlias.h b/dbms/src/Interpreters/DatabaseAndTableWithAlias.h index 0f1cbe8bbc7..22b03a5ed44 100644 --- a/dbms/src/Interpreters/DatabaseAndTableWithAlias.h +++ b/dbms/src/Interpreters/DatabaseAndTableWithAlias.h @@ -1,18 +1,16 @@ #pragma once +#include +#include +#include + #include #include -#include -#include - namespace DB { -class IAST; -using ASTPtr = std::shared_ptr; - class ASTSelectQuery; class ASTIdentifier; struct ASTTableExpression; diff --git a/dbms/src/Interpreters/ExecuteScalarSubqueriesVisitor.cpp b/dbms/src/Interpreters/ExecuteScalarSubqueriesVisitor.cpp index b6cbaaf181b..57a08994426 100644 --- a/dbms/src/Interpreters/ExecuteScalarSubqueriesVisitor.cpp +++ b/dbms/src/Interpreters/ExecuteScalarSubqueriesVisitor.cpp @@ -41,19 +41,17 @@ static ASTPtr addTypeConversion(std::unique_ptr && ast, const String bool ExecuteScalarSubqueriesMatcher::needChildVisit(ASTPtr & node, const ASTPtr & child) { /// Processed - if (typeid_cast(node.get()) || - typeid_cast(node.get())) + if (node->as() || node->as()) return false; /// Don't descend into subqueries in FROM section - if (typeid_cast(node.get())) + if (node->as()) return false; - if (typeid_cast(node.get())) + if (node->as()) { /// Do not go to FROM, JOIN, UNION. - if (typeid_cast(child.get()) || - typeid_cast(child.get())) + if (child->as() || child->as()) return false; } @@ -62,9 +60,9 @@ bool ExecuteScalarSubqueriesMatcher::needChildVisit(ASTPtr & node, const ASTPtr void ExecuteScalarSubqueriesMatcher::visit(ASTPtr & ast, Data & data) { - if (auto * t = typeid_cast(ast.get())) + if (const auto * t = ast->as()) visit(*t, ast, data); - if (auto * t = typeid_cast(ast.get())) + if (const auto * t = ast->as()) visit(*t, ast, data); } @@ -147,7 +145,7 @@ void ExecuteScalarSubqueriesMatcher::visit(const ASTFunction & func, ASTPtr & as out.push_back(&child); else for (size_t i = 0, size = func.arguments->children.size(); i < size; ++i) - if (i != 1 || !typeid_cast(func.arguments->children[i].get())) + if (i != 1 || !func.arguments->children[i]->as()) out.push_back(&func.arguments->children[i]); } } diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index 4a2e62fe1ec..8e69d5ecfee 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -90,8 +90,6 @@ ExpressionAnalyzer::ExpressionAnalyzer( storage = syntax->storage; rewrite_subqueries = syntax->rewrite_subqueries; - select_query = typeid_cast(query.get()); - if (!additional_source_columns.empty()) { source_columns.insert(source_columns.end(), additional_source_columns.begin(), additional_source_columns.end()); @@ -130,6 +128,8 @@ void ExpressionAnalyzer::analyzeAggregation() * Everything below (compiling temporary ExpressionActions) - only for the purpose of query analysis (type output). */ + auto * select_query = query->as(); + if (select_query && (select_query->group_expression_list || select_query->having_expression)) has_aggregation = true; @@ -149,7 +149,7 @@ void ExpressionAnalyzer::analyzeAggregation() const ASTTablesInSelectQueryElement * join = select_query->join(); if (join) { - const auto table_join = static_cast(*join->table_join); + const auto & table_join = join->table_join->as(); if (table_join.using_expression_list) getRootActions(table_join.using_expression_list, true, temp_actions); if (table_join.on_expression) @@ -250,6 +250,8 @@ void ExpressionAnalyzer::initGlobalSubqueriesAndExternalTables() void ExpressionAnalyzer::makeSetsForIndex() { + const auto * select_query = query->as(); + if (storage && select_query && storage->supportsIndexForIn()) { if (select_query->where_expression) @@ -288,18 +290,18 @@ void ExpressionAnalyzer::makeSetsForIndexImpl(const ASTPtr & node) for (auto & child : node->children) { /// Don't descend into subqueries. - if (typeid_cast(child.get())) + if (child->as()) continue; /// Don't descend into lambda functions - const ASTFunction * func = typeid_cast(child.get()); + const auto * func = child->as(); if (func && func->name == "lambda") continue; makeSetsForIndexImpl(child); } - const ASTFunction * func = typeid_cast(node.get()); + const auto * func = node->as(); if (func && functionIsInOperator(func->name)) { const IAST & args = *func->arguments; @@ -307,7 +309,7 @@ void ExpressionAnalyzer::makeSetsForIndexImpl(const ASTPtr & node) if (storage && storage->mayBenefitFromIndexForIn(args.children.at(0), context)) { const ASTPtr & arg = args.children.at(1); - if (typeid_cast(arg.get()) || isIdentifier(arg)) + if (arg->as() || arg->as()) { if (settings.use_index_for_in_with_subqueries) tryMakeSetForIndexFromSubquery(arg); @@ -365,6 +367,8 @@ void ExpressionAnalyzer::getActionsFromJoinKeys(const ASTTableJoin & table_join, void ExpressionAnalyzer::getAggregates(const ASTPtr & ast, ExpressionActionsPtr & actions) { + const auto * select_query = query->as(); + /// There can not be aggregate functions inside the WHERE and PREWHERE. if (select_query && (ast.get() == select_query->where_expression.get() || ast.get() == select_query->prewhere_expression.get())) { @@ -379,7 +383,7 @@ void ExpressionAnalyzer::getAggregates(const ASTPtr & ast, ExpressionActionsPtr return; } - const ASTFunction * node = typeid_cast(ast.get()); + const auto * node = ast->as(); if (node && AggregateFunctionFactory::instance().isAggregateFunctionName(node->name)) { has_aggregation = true; @@ -414,8 +418,7 @@ void ExpressionAnalyzer::getAggregates(const ASTPtr & ast, ExpressionActionsPtr else { for (const auto & child : ast->children) - if (!typeid_cast(child.get()) - && !typeid_cast(child.get())) + if (!child->as() && !child->as()) getAggregates(child, actions); } } @@ -423,21 +426,22 @@ void ExpressionAnalyzer::getAggregates(const ASTPtr & ast, ExpressionActionsPtr void ExpressionAnalyzer::assertNoAggregates(const ASTPtr & ast, const char * description) { - const ASTFunction * node = typeid_cast(ast.get()); + const auto * node = ast->as(); if (node && AggregateFunctionFactory::instance().isAggregateFunctionName(node->name)) throw Exception("Aggregate function " + node->getColumnName() + " is found " + String(description) + " in query", ErrorCodes::ILLEGAL_AGGREGATION); for (const auto & child : ast->children) - if (!typeid_cast(child.get()) - && !typeid_cast(child.get())) + if (!child->as() && !child->as()) assertNoAggregates(child, description); } void ExpressionAnalyzer::assertSelect() const { + const auto * select_query = query->as(); + if (!select_query) throw Exception("Not a select query", ErrorCodes::LOGICAL_ERROR); } @@ -475,6 +479,8 @@ void ExpressionAnalyzer::addMultipleArrayJoinAction(ExpressionActionsPtr & actio bool ExpressionAnalyzer::appendArrayJoin(ExpressionActionsChain & chain, bool only_types) { + const auto * select_query = query->as(); + assertSelect(); bool is_array_join_left; @@ -520,6 +526,8 @@ static void appendRequiredColumns(NameSet & required_columns, const Block & samp bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_types) { + const auto * select_query = query->as(); + assertSelect(); if (!select_query->join()) @@ -528,8 +536,8 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty initChain(chain, source_columns); ExpressionActionsChain::Step & step = chain.steps.back(); - const auto & join_element = static_cast(*select_query->join()); - auto & join_params = static_cast(*join_element.table_join); + const auto & join_element = select_query->join()->as(); + auto & join_params = join_element.table_join->as(); if (join_params.strictness == ASTTableJoin::Strictness::Unspecified && join_params.kind != ASTTableJoin::Kind::Cross) { @@ -541,7 +549,7 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty throw Exception("Expected ANY or ALL in JOIN section, because setting (join_default_strictness) is empty", DB::ErrorCodes::EXPECTED_ALL_OR_ANY); } - const auto & table_to_join = static_cast(*join_element.table_expression); + const auto & table_to_join = join_element.table_expression->as(); getActionsFromJoinKeys(join_params, only_types, step.actions); @@ -559,7 +567,7 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty if (table) { - StorageJoin * storage_join = dynamic_cast(table.get()); + auto * storage_join = dynamic_cast(table.get()); if (storage_join) { @@ -624,6 +632,8 @@ bool ExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, bool only_ty bool ExpressionAnalyzer::appendPrewhere( ExpressionActionsChain & chain, bool only_types, const Names & additional_required_columns) { + const auto * select_query = query->as(); + assertSelect(); if (!select_query->prewhere_expression) @@ -697,6 +707,8 @@ bool ExpressionAnalyzer::appendPrewhere( bool ExpressionAnalyzer::appendWhere(ExpressionActionsChain & chain, bool only_types) { + const auto * select_query = query->as(); + assertSelect(); if (!select_query->where_expression) @@ -715,6 +727,8 @@ bool ExpressionAnalyzer::appendWhere(ExpressionActionsChain & chain, bool only_t bool ExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain, bool only_types) { + const auto * select_query = query->as(); + assertAggregation(); if (!select_query->group_expression_list) @@ -735,6 +749,8 @@ bool ExpressionAnalyzer::appendGroupBy(ExpressionActionsChain & chain, bool only void ExpressionAnalyzer::appendAggregateFunctionsArguments(ExpressionActionsChain & chain, bool only_types) { + const auto * select_query = query->as(); + assertAggregation(); initChain(chain, source_columns); @@ -759,6 +775,8 @@ void ExpressionAnalyzer::appendAggregateFunctionsArguments(ExpressionActionsChai bool ExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain, bool only_types) { + const auto * select_query = query->as(); + assertAggregation(); if (!select_query->having_expression) @@ -775,6 +793,8 @@ bool ExpressionAnalyzer::appendHaving(ExpressionActionsChain & chain, bool only_ void ExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain, bool only_types) { + const auto * select_query = query->as(); + assertSelect(); initChain(chain, aggregated_columns); @@ -788,6 +808,8 @@ void ExpressionAnalyzer::appendSelect(ExpressionActionsChain & chain, bool only_ bool ExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only_types) { + const auto * select_query = query->as(); + assertSelect(); if (!select_query->order_expression_list) @@ -801,7 +823,7 @@ bool ExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only ASTs asts = select_query->order_expression_list->children; for (size_t i = 0; i < asts.size(); ++i) { - ASTOrderByElement * ast = typeid_cast(asts[i].get()); + const auto * ast = asts[i]->as(); if (!ast || ast->children.size() < 1) throw Exception("Bad order expression AST", ErrorCodes::UNKNOWN_TYPE_OF_AST_NODE); ASTPtr order_expression = ast->children.at(0); @@ -813,6 +835,8 @@ bool ExpressionAnalyzer::appendOrderBy(ExpressionActionsChain & chain, bool only bool ExpressionAnalyzer::appendLimitBy(ExpressionActionsChain & chain, bool only_types) { + const auto * select_query = query->as(); + assertSelect(); if (!select_query->limit_by_expression_list) @@ -831,6 +855,8 @@ bool ExpressionAnalyzer::appendLimitBy(ExpressionActionsChain & chain, bool only void ExpressionAnalyzer::appendProjectResult(ExpressionActionsChain & chain) const { + const auto * select_query = query->as(); + assertSelect(); initChain(chain, aggregated_columns); @@ -864,7 +890,7 @@ void ExpressionAnalyzer::appendExpression(ExpressionActionsChain & chain, const void ExpressionAnalyzer::getActionsBeforeAggregation(const ASTPtr & ast, ExpressionActionsPtr & actions, bool no_subqueries) { - ASTFunction * node = typeid_cast(ast.get()); + const auto * node = ast->as(); if (node && AggregateFunctionFactory::instance().isAggregateFunctionName(node->name)) for (auto & argument : node->arguments->children) @@ -883,7 +909,7 @@ ExpressionActionsPtr ExpressionAnalyzer::getActions(bool add_aliases, bool proje ASTs asts; - if (auto node = typeid_cast(query.get())) + if (const auto * node = query->as()) asts = node->children; else asts = ASTs(1, query); @@ -965,21 +991,6 @@ void ExpressionAnalyzer::collectUsedColumns() if (columns_context.has_table_join) { const AnalyzedJoin & analyzed_join = analyzedJoin(); -#if 0 - std::cerr << "key_names_left: "; - for (const auto & name : analyzed_join.key_names_left) - std::cerr << "'" << name << "' "; - std::cerr << "key_names_right: "; - for (const auto & name : analyzed_join.key_names_right) - std::cerr << "'" << name << "' "; - std::cerr << "columns_from_joined_table: "; - for (const auto & column : analyzed_join.columns_from_joined_table) - std::cerr << "'" << column.name_and_type.name << '/' << column.original_name << "' "; - std::cerr << "available_joined_columns: "; - for (const auto & column : analyzed_join.available_joined_columns) - std::cerr << "'" << column.name_and_type.name << '/' << column.original_name << "' "; - std::cerr << std::endl; -#endif NameSet avaliable_columns; for (const auto & name : source_columns) avaliable_columns.insert(name.name); @@ -1014,6 +1025,8 @@ void ExpressionAnalyzer::collectUsedColumns() required.insert(column_name_type.name); } + const auto * select_query = query->as(); + /// You need to read at least one column to find the number of rows. if (select_query && required.empty()) required.insert(ExpressionActions::getSmallestColumn(source_columns)); diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.h b/dbms/src/Interpreters/ExpressionAnalyzer.h index 83598aba72d..b9ec5c3ca70 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.h +++ b/dbms/src/Interpreters/ExpressionAnalyzer.h @@ -1,9 +1,11 @@ #pragma once +#include #include #include -#include #include +#include + namespace DB { @@ -15,9 +17,6 @@ struct ExpressionActionsChain; class ExpressionActions; using ExpressionActionsPtr = std::shared_ptr; -class IAST; -using ASTPtr = std::shared_ptr; -using ASTs = std::vector; struct ASTTableJoin; class IBlockInputStream; @@ -211,7 +210,6 @@ public: private: ASTPtr query; - ASTSelectQuery * select_query; const Context & context; const ExtractedSettings settings; StoragePtr storage; /// The main table in FROM clause, if exists. diff --git a/dbms/src/Interpreters/ExternalTablesVisitor.h b/dbms/src/Interpreters/ExternalTablesVisitor.h index c792b3a1856..3ff5f38f84d 100644 --- a/dbms/src/Interpreters/ExternalTablesVisitor.h +++ b/dbms/src/Interpreters/ExternalTablesVisitor.h @@ -22,7 +22,7 @@ public: static void visit(ASTPtr & ast, Data & data) { - if (auto * t = typeid_cast(ast.get())) + if (const auto * t = ast->as()) visit(*t, ast, data); } diff --git a/dbms/src/Interpreters/GlobalSubqueriesVisitor.h b/dbms/src/Interpreters/GlobalSubqueriesVisitor.h index 644ed7fe5d2..c99f8d95e49 100644 --- a/dbms/src/Interpreters/GlobalSubqueriesVisitor.h +++ b/dbms/src/Interpreters/GlobalSubqueriesVisitor.h @@ -56,12 +56,12 @@ public: ASTPtr table_name; ASTPtr subquery_or_table_name; - if (isIdentifier(subquery_or_table_name_or_table_expression)) + if (subquery_or_table_name_or_table_expression->as()) { table_name = subquery_or_table_name_or_table_expression; subquery_or_table_name = table_name; } - else if (auto ast_table_expr = typeid_cast(subquery_or_table_name_or_table_expression.get())) + else if (const auto * ast_table_expr = subquery_or_table_name_or_table_expression->as()) { if (ast_table_expr->database_and_table_name) { @@ -74,7 +74,7 @@ public: subquery_or_table_name = subquery; } } - else if (typeid_cast(subquery_or_table_name_or_table_expression.get())) + else if (subquery_or_table_name_or_table_expression->as()) { subquery = subquery_or_table_name_or_table_expression; subquery_or_table_name = subquery; @@ -115,7 +115,7 @@ public: auto database_and_table_name = createTableIdentifier("", external_table_name); - if (auto ast_table_expr = typeid_cast(subquery_or_table_name_or_table_expression.get())) + if (auto * ast_table_expr = subquery_or_table_name_or_table_expression->as()) { ast_table_expr->subquery.reset(); ast_table_expr->database_and_table_name = database_and_table_name; @@ -140,16 +140,16 @@ public: static void visit(ASTPtr & ast, Data & data) { - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) visit(*t, ast, data); - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) visit(*t, ast, data); } static bool needChildVisit(ASTPtr &, const ASTPtr & child) { /// We do not go into subqueries. - if (typeid_cast(child.get())) + if (child->as()) return false; return true; } @@ -168,8 +168,7 @@ private: /// GLOBAL JOIN static void visit(ASTTablesInSelectQueryElement & table_elem, ASTPtr &, Data & data) { - if (table_elem.table_join - && static_cast(*table_elem.table_join).locality == ASTTableJoin::Locality::Global) + if (table_elem.table_join && table_elem.table_join->as().locality == ASTTableJoin::Locality::Global) { data.addExternalStorage(table_elem.table_expression); data.has_global_subqueries = true; diff --git a/dbms/src/Interpreters/IdentifierSemantic.cpp b/dbms/src/Interpreters/IdentifierSemantic.cpp index e4ea91534af..361462c0d1d 100644 --- a/dbms/src/Interpreters/IdentifierSemantic.cpp +++ b/dbms/src/Interpreters/IdentifierSemantic.cpp @@ -15,7 +15,7 @@ std::optional IdentifierSemantic::getColumnName(const ASTIdentifier & no std::optional IdentifierSemantic::getColumnName(const ASTPtr & ast) { if (ast) - if (auto id = typeid_cast(ast.get())) + if (const auto * id = ast->as()) if (!id->semantic->special) return id->name; return {}; @@ -31,7 +31,7 @@ std::optional IdentifierSemantic::getTableName(const ASTIdentifier & nod std::optional IdentifierSemantic::getTableName(const ASTPtr & ast) { if (ast) - if (auto id = typeid_cast(ast.get())) + if (const auto * id = ast->as()) if (id->semantic->special) return id->name; return {}; @@ -144,7 +144,7 @@ void IdentifierSemantic::setColumnLongName(ASTIdentifier & identifier, const Dat String IdentifierSemantic::columnNormalName(const ASTIdentifier & identifier, const DatabaseAndTableWithAlias & db_and_table) { ASTPtr copy = identifier.clone(); - setColumnNormalName(typeid_cast(*copy), db_and_table); + setColumnNormalName(copy->as(), db_and_table); return copy->getAliasOrColumnName(); } diff --git a/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.cpp b/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.cpp index e0be0d068e0..9e814b00b63 100644 --- a/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.cpp +++ b/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.cpp @@ -30,7 +30,7 @@ namespace template void forEachNonGlobalSubquery(IAST * node, F && f) { - if (ASTFunction * function = typeid_cast(node)) + if (auto * function = node->as()) { if (function->name == "in" || function->name == "notIn") { @@ -40,14 +40,14 @@ void forEachNonGlobalSubquery(IAST * node, F && f) /// Pass into other functions, as subquery could be in aggregate or in lambda functions. } - else if (ASTTablesInSelectQueryElement * join = typeid_cast(node)) + else if (const auto * join = node->as()) { if (join->table_join && join->table_expression) { - auto & table_join = static_cast(*join->table_join); + auto & table_join = join->table_join->as(); if (table_join.locality != ASTTableJoin::Locality::Global) { - auto & subquery = static_cast(*join->table_expression).subquery; + auto & subquery = join->table_expression->as()->subquery; if (subquery) f(subquery.get(), nullptr, &table_join); } @@ -59,7 +59,7 @@ void forEachNonGlobalSubquery(IAST * node, F && f) /// Descent into all children, but not into subqueries of other kind (scalar subqueries), that are irrelevant to us. for (auto & child : node->children) - if (!typeid_cast(child.get())) + if (!child->as()) forEachNonGlobalSubquery(child.get(), f); } @@ -69,7 +69,7 @@ void forEachNonGlobalSubquery(IAST * node, F && f) template void forEachTable(IAST * node, F && f) { - if (auto table_expression = typeid_cast(node)) + if (auto * table_expression = node->as()) { auto & database_and_table = table_expression->database_and_table_name; if (database_and_table) @@ -103,15 +103,15 @@ void InJoinSubqueriesPreprocessor::process(ASTSelectQuery * query) const if (!query->tables) return; - ASTTablesInSelectQuery & tables_in_select_query = static_cast(*query->tables); + const auto & tables_in_select_query = query->tables->as(); if (tables_in_select_query.children.empty()) return; - ASTTablesInSelectQueryElement & tables_element = static_cast(*tables_in_select_query.children[0]); + const auto & tables_element = tables_in_select_query.children[0]->as(); if (!tables_element.table_expression) return; - ASTTableExpression * table_expression = static_cast(tables_element.table_expression.get()); + const auto * table_expression = tables_element.table_expression->as(); /// If not ordinary table, skip it. if (!table_expression->database_and_table_name) @@ -143,7 +143,7 @@ void InJoinSubqueriesPreprocessor::process(ASTSelectQuery * query) const { if (function) { - ASTFunction * concrete = static_cast(function); + auto * concrete = function->as(); if (concrete->name == "in") concrete->name = "globalIn"; @@ -157,7 +157,7 @@ void InJoinSubqueriesPreprocessor::process(ASTSelectQuery * query) const throw Exception("Logical error: unexpected function name " + concrete->name, ErrorCodes::LOGICAL_ERROR); } else if (table_join) - static_cast(*table_join).locality = ASTTableJoin::Locality::Global; + table_join->as().locality = ASTTableJoin::Locality::Global; else throw Exception("Logical error: unexpected AST node", ErrorCodes::LOGICAL_ERROR); } diff --git a/dbms/src/Interpreters/InterpreterAlterQuery.cpp b/dbms/src/Interpreters/InterpreterAlterQuery.cpp index c80001dc2fc..8751ff067b1 100644 --- a/dbms/src/Interpreters/InterpreterAlterQuery.cpp +++ b/dbms/src/Interpreters/InterpreterAlterQuery.cpp @@ -30,7 +30,7 @@ InterpreterAlterQuery::InterpreterAlterQuery(const ASTPtr & query_ptr_, const Co BlockIO InterpreterAlterQuery::execute() { - auto & alter = typeid_cast(*query_ptr); + const auto & alter = query_ptr->as(); if (!alter.cluster.empty()) return executeDDLQueryOnCluster(query_ptr, context, {alter.database}); diff --git a/dbms/src/Interpreters/InterpreterAlterQuery.h b/dbms/src/Interpreters/InterpreterAlterQuery.h index bd9f3a89d6a..776409a225d 100644 --- a/dbms/src/Interpreters/InterpreterAlterQuery.h +++ b/dbms/src/Interpreters/InterpreterAlterQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; /** Allows you add or remove a column in the table. * It also allows you to manipulate the partitions of the MergeTree family tables. diff --git a/dbms/src/Interpreters/InterpreterCheckQuery.cpp b/dbms/src/Interpreters/InterpreterCheckQuery.cpp index 84788fd0685..c99c74fa33a 100644 --- a/dbms/src/Interpreters/InterpreterCheckQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCheckQuery.cpp @@ -19,8 +19,8 @@ InterpreterCheckQuery::InterpreterCheckQuery(const ASTPtr & query_ptr_, const Co BlockIO InterpreterCheckQuery::execute() { - ASTCheckQuery & alter = typeid_cast(*query_ptr); - String & table_name = alter.table; + const auto & alter = query_ptr->as(); + const String & table_name = alter.table; String database_name = alter.database.empty() ? context.getCurrentDatabase() : alter.database; StoragePtr table = context.getTable(database_name, table_name); diff --git a/dbms/src/Interpreters/InterpreterCreateQuery.cpp b/dbms/src/Interpreters/InterpreterCreateQuery.cpp index e4eb78c2b01..4e669faa512 100644 --- a/dbms/src/Interpreters/InterpreterCreateQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateQuery.cpp @@ -197,7 +197,7 @@ static ColumnsDeclarationAndModifiers parseColumns(const ASTExpressionList & col for (const auto & ast : column_list_ast.children) { - auto & col_decl = typeid_cast(*ast); + auto & col_decl = ast->as(); DataTypePtr column_type = nullptr; if (col_decl.type) @@ -240,7 +240,7 @@ static ColumnsDeclarationAndModifiers parseColumns(const ASTExpressionList & col if (col_decl.comment) { - if (auto comment_str = typeid_cast(*col_decl.comment).value.get(); !comment_str.empty()) + if (auto comment_str = col_decl.comment->as().value.get(); !comment_str.empty()) comments.emplace(col_decl.name, comment_str); } } @@ -526,7 +526,7 @@ void InterpreterCreateQuery::setEngine(ASTCreateQuery & create) const String as_table_name = create.as_table; ASTPtr as_create_ptr = context.getCreateTableQuery(as_database_name, as_table_name); - const auto & as_create = typeid_cast(*as_create_ptr); + const auto & as_create = as_create_ptr->as(); if (as_create.is_view) throw Exception( @@ -566,8 +566,7 @@ BlockIO InterpreterCreateQuery::createTable(ASTCreateQuery & create) { // Table SQL definition is available even if the table is detached auto query = context.getCreateTableQuery(database_name, table_name); - auto & as_create = typeid_cast(*query); - create = as_create; // Copy the saved create query, but use ATTACH instead of CREATE + create = query->as(); // Copy the saved create query, but use ATTACH instead of CREATE create.attach = true; } @@ -695,7 +694,7 @@ BlockIO InterpreterCreateQuery::createTable(ASTCreateQuery & create) BlockIO InterpreterCreateQuery::execute() { - ASTCreateQuery & create = typeid_cast(*query_ptr); + auto & create = query_ptr->as(); checkAccess(create); ASTQueryWithOutput::resetOutputASTIfExist(create); diff --git a/dbms/src/Interpreters/InterpreterDescribeQuery.cpp b/dbms/src/Interpreters/InterpreterDescribeQuery.cpp index dfca483f7ad..f91fb1c0a2a 100644 --- a/dbms/src/Interpreters/InterpreterDescribeQuery.cpp +++ b/dbms/src/Interpreters/InterpreterDescribeQuery.cpp @@ -58,7 +58,7 @@ Block InterpreterDescribeQuery::getSampleBlock() BlockInputStreamPtr InterpreterDescribeQuery::executeImpl() { - const ASTDescribeQuery & ast = typeid_cast(*query_ptr); + const auto & ast = query_ptr->as(); NamesAndTypesList columns; ColumnDefaults column_defaults; @@ -66,7 +66,7 @@ BlockInputStreamPtr InterpreterDescribeQuery::executeImpl() ColumnCodecs column_codecs; StoragePtr table; - auto table_expression = typeid_cast(ast.table_expression.get()); + const auto * table_expression = ast.table_expression->as(); if (table_expression->subquery) { @@ -76,7 +76,7 @@ BlockInputStreamPtr InterpreterDescribeQuery::executeImpl() { if (table_expression->table_function) { - auto table_function = typeid_cast(table_expression->table_function.get()); + const auto * table_function = table_expression->table_function->as(); /// Get the table function TableFunctionPtr table_function_ptr = TableFunctionFactory::instance().get(table_function->name, context); /// Run it and remember the result @@ -84,7 +84,7 @@ BlockInputStreamPtr InterpreterDescribeQuery::executeImpl() } else { - auto identifier = typeid_cast(table_expression->database_and_table_name.get()); + const auto * identifier = table_expression->database_and_table_name->as(); String database_name; String table_name; diff --git a/dbms/src/Interpreters/InterpreterDescribeQuery.h b/dbms/src/Interpreters/InterpreterDescribeQuery.h index fc0bea10f2d..4fafe61f229 100644 --- a/dbms/src/Interpreters/InterpreterDescribeQuery.h +++ b/dbms/src/Interpreters/InterpreterDescribeQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; /** Return names, types and other information about columns in specified table. diff --git a/dbms/src/Interpreters/InterpreterDropQuery.cpp b/dbms/src/Interpreters/InterpreterDropQuery.cpp index cc422c445fc..91213b6100e 100644 --- a/dbms/src/Interpreters/InterpreterDropQuery.cpp +++ b/dbms/src/Interpreters/InterpreterDropQuery.cpp @@ -31,7 +31,7 @@ InterpreterDropQuery::InterpreterDropQuery(const ASTPtr & query_ptr_, Context & BlockIO InterpreterDropQuery::execute() { - ASTDropQuery & drop = typeid_cast(*query_ptr); + auto & drop = query_ptr->as(); checkAccess(drop); diff --git a/dbms/src/Interpreters/InterpreterDropQuery.h b/dbms/src/Interpreters/InterpreterDropQuery.h index 986c73f8465..8ca91610cbb 100644 --- a/dbms/src/Interpreters/InterpreterDropQuery.h +++ b/dbms/src/Interpreters/InterpreterDropQuery.h @@ -1,16 +1,14 @@ #pragma once +#include #include #include -#include - +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; using DatabaseAndTable = std::pair; /** Allow to either drop table with all its data (DROP), diff --git a/dbms/src/Interpreters/InterpreterExistsQuery.cpp b/dbms/src/Interpreters/InterpreterExistsQuery.cpp index d718479e669..27c78fe430d 100644 --- a/dbms/src/Interpreters/InterpreterExistsQuery.cpp +++ b/dbms/src/Interpreters/InterpreterExistsQuery.cpp @@ -32,7 +32,7 @@ Block InterpreterExistsQuery::getSampleBlock() BlockInputStreamPtr InterpreterExistsQuery::executeImpl() { - const ASTExistsQuery & ast = typeid_cast(*query_ptr); + const auto & ast = query_ptr->as(); bool res = ast.temporary ? context.isExternalTableExist(ast.table) : context.isTableExist(ast.database, ast.table); return std::make_shared(Block{{ diff --git a/dbms/src/Interpreters/InterpreterExistsQuery.h b/dbms/src/Interpreters/InterpreterExistsQuery.h index b2050b443d8..1860e1d0aa9 100644 --- a/dbms/src/Interpreters/InterpreterExistsQuery.h +++ b/dbms/src/Interpreters/InterpreterExistsQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; /** Check that table exists. Return single row with single column "result" of type UInt8 and value 0 or 1. diff --git a/dbms/src/Interpreters/InterpreterExplainQuery.cpp b/dbms/src/Interpreters/InterpreterExplainQuery.cpp index be7a592ecb9..971de38d11a 100644 --- a/dbms/src/Interpreters/InterpreterExplainQuery.cpp +++ b/dbms/src/Interpreters/InterpreterExplainQuery.cpp @@ -39,7 +39,7 @@ Block InterpreterExplainQuery::getSampleBlock() BlockInputStreamPtr InterpreterExplainQuery::executeImpl() { - const ASTExplainQuery & ast = typeid_cast(*query); + const auto & ast = query->as(); Block sample_block = getSampleBlock(); MutableColumns res_columns = sample_block.cloneEmptyColumns(); diff --git a/dbms/src/Interpreters/InterpreterExplainQuery.h b/dbms/src/Interpreters/InterpreterExplainQuery.h index 4db796ad014..0d3b183857b 100644 --- a/dbms/src/Interpreters/InterpreterExplainQuery.h +++ b/dbms/src/Interpreters/InterpreterExplainQuery.h @@ -2,15 +2,12 @@ #include #include +#include namespace DB { -class IAST; -using ASTPtr = std::shared_ptr; - - /// Returns single row with explain results class InterpreterExplainQuery : public IInterpreter { diff --git a/dbms/src/Interpreters/InterpreterFactory.cpp b/dbms/src/Interpreters/InterpreterFactory.cpp index 5d1b259cc0d..b2497481361 100644 --- a/dbms/src/Interpreters/InterpreterFactory.cpp +++ b/dbms/src/Interpreters/InterpreterFactory.cpp @@ -80,95 +80,95 @@ std::unique_ptr InterpreterFactory::get(ASTPtr & query, Context & { ProfileEvents::increment(ProfileEvents::Query); - if (typeid_cast(query.get())) + if (query->as()) { /// This is internal part of ASTSelectWithUnionQuery. /// Even if there is SELECT without union, it is represented by ASTSelectWithUnionQuery with single ASTSelectQuery as a child. return std::make_unique(query, context, Names{}, stage); } - else if (typeid_cast(query.get())) + else if (query->as()) { ProfileEvents::increment(ProfileEvents::SelectQuery); return std::make_unique(query, context, Names{}, stage); } - else if (typeid_cast(query.get())) + else if (query->as()) { ProfileEvents::increment(ProfileEvents::InsertQuery); /// readonly is checked inside InterpreterInsertQuery bool allow_materialized = static_cast(context.getSettingsRef().insert_allow_materialized_columns); return std::make_unique(query, context, allow_materialized); } - else if (typeid_cast(query.get())) + else if (query->as()) { /// readonly and allow_ddl are checked inside InterpreterCreateQuery return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { /// readonly and allow_ddl are checked inside InterpreterDropQuery return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { throwIfNoAccess(context); return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { /// readonly is checked inside InterpreterSetQuery return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { throwIfNoAccess(context); return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { throwIfNoAccess(context); return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { return std::make_unique(query, context); } - else if (typeid_cast(query.get())) + else if (query->as()) { throwIfNoAccess(context); return std::make_unique(query, context); diff --git a/dbms/src/Interpreters/InterpreterFactory.h b/dbms/src/Interpreters/InterpreterFactory.h index 25bad6659d5..1f065bbf69b 100644 --- a/dbms/src/Interpreters/InterpreterFactory.h +++ b/dbms/src/Interpreters/InterpreterFactory.h @@ -2,14 +2,13 @@ #include #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; class InterpreterFactory diff --git a/dbms/src/Interpreters/InterpreterInsertQuery.cpp b/dbms/src/Interpreters/InterpreterInsertQuery.cpp index d5c2600eda4..d5b67e5518a 100644 --- a/dbms/src/Interpreters/InterpreterInsertQuery.cpp +++ b/dbms/src/Interpreters/InterpreterInsertQuery.cpp @@ -46,7 +46,7 @@ StoragePtr InterpreterInsertQuery::getTable(const ASTInsertQuery & query) { if (query.table_function) { - auto table_function = typeid_cast(query.table_function.get()); + const auto * table_function = query.table_function->as(); const auto & factory = TableFunctionFactory::instance(); return factory.get(table_function->name, context)->execute(query.table_function, context); } @@ -92,7 +92,7 @@ Block InterpreterInsertQuery::getSampleBlock(const ASTInsertQuery & query, const BlockIO InterpreterInsertQuery::execute() { - ASTInsertQuery & query = typeid_cast(*query_ptr); + const auto & query = query_ptr->as(); checkAccess(query); StoragePtr table = getTable(query); @@ -171,7 +171,7 @@ void InterpreterInsertQuery::checkAccess(const ASTInsertQuery & query) std::pair InterpreterInsertQuery::getDatabaseTable() const { - ASTInsertQuery & query = typeid_cast(*query_ptr); + const auto & query = query_ptr->as(); return {query.database, query.table}; } diff --git a/dbms/src/Interpreters/InterpreterKillQueryQuery.cpp b/dbms/src/Interpreters/InterpreterKillQueryQuery.cpp index 66f6248f672..89339668088 100644 --- a/dbms/src/Interpreters/InterpreterKillQueryQuery.cpp +++ b/dbms/src/Interpreters/InterpreterKillQueryQuery.cpp @@ -172,7 +172,7 @@ public: BlockIO InterpreterKillQueryQuery::execute() { - ASTKillQueryQuery & query = typeid_cast(*query_ptr); + const auto & query = query_ptr->as(); if (!query.cluster.empty()) return executeDDLQueryOnCluster(query_ptr, context, {"system"}); @@ -261,7 +261,7 @@ BlockIO InterpreterKillQueryQuery::execute() Block InterpreterKillQueryQuery::getSelectResult(const String & columns, const String & table) { String select_query = "SELECT " + columns + " FROM " + table; - auto & where_expression = static_cast(*query_ptr).where_expression; + auto & where_expression = query_ptr->as()->where_expression; if (where_expression) select_query += " WHERE " + queryToString(where_expression); diff --git a/dbms/src/Interpreters/InterpreterKillQueryQuery.h b/dbms/src/Interpreters/InterpreterKillQueryQuery.h index 9294e45eab8..fab4b304865 100644 --- a/dbms/src/Interpreters/InterpreterKillQueryQuery.h +++ b/dbms/src/Interpreters/InterpreterKillQueryQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; class InterpreterKillQueryQuery : public IInterpreter @@ -28,4 +27,3 @@ private: } - diff --git a/dbms/src/Interpreters/InterpreterOptimizeQuery.cpp b/dbms/src/Interpreters/InterpreterOptimizeQuery.cpp index 47e77172eae..7326bb62924 100644 --- a/dbms/src/Interpreters/InterpreterOptimizeQuery.cpp +++ b/dbms/src/Interpreters/InterpreterOptimizeQuery.cpp @@ -17,7 +17,7 @@ namespace ErrorCodes BlockIO InterpreterOptimizeQuery::execute() { - const ASTOptimizeQuery & ast = typeid_cast(*query_ptr); + const auto & ast = query_ptr->as(); if (!ast.cluster.empty()) return executeDDLQueryOnCluster(query_ptr, context, {ast.database}); diff --git a/dbms/src/Interpreters/InterpreterOptimizeQuery.h b/dbms/src/Interpreters/InterpreterOptimizeQuery.h index 03c369b0e4a..251d8ea02f3 100644 --- a/dbms/src/Interpreters/InterpreterOptimizeQuery.h +++ b/dbms/src/Interpreters/InterpreterOptimizeQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; /** Just call method "optimize" for table. diff --git a/dbms/src/Interpreters/InterpreterRenameQuery.cpp b/dbms/src/Interpreters/InterpreterRenameQuery.cpp index 77a0c862905..360adf45194 100644 --- a/dbms/src/Interpreters/InterpreterRenameQuery.cpp +++ b/dbms/src/Interpreters/InterpreterRenameQuery.cpp @@ -36,7 +36,7 @@ struct RenameDescription BlockIO InterpreterRenameQuery::execute() { - ASTRenameQuery & rename = typeid_cast(*query_ptr); + const auto & rename = query_ptr->as(); if (!rename.cluster.empty()) { diff --git a/dbms/src/Interpreters/InterpreterRenameQuery.h b/dbms/src/Interpreters/InterpreterRenameQuery.h index e9a36277ef0..0433c638468 100644 --- a/dbms/src/Interpreters/InterpreterRenameQuery.h +++ b/dbms/src/Interpreters/InterpreterRenameQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; /** Rename one table diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index 83d53ca37b8..7dcd70dfc39 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -23,7 +23,6 @@ #include #include -#include #include #include #include @@ -169,7 +168,7 @@ InterpreterSelectQuery::InterpreterSelectQuery( } max_streams = settings.max_threads; - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); ASTPtr table_expression = extractTableExpression(query, 0); @@ -177,8 +176,8 @@ InterpreterSelectQuery::InterpreterSelectQuery( bool is_subquery = false; if (table_expression) { - is_table_func = typeid_cast(table_expression.get()); - is_subquery = typeid_cast(table_expression.get()); + is_table_func = table_expression->as(); + is_subquery = table_expression->as(); } if (input) @@ -277,15 +276,9 @@ InterpreterSelectQuery::InterpreterSelectQuery( } -ASTSelectQuery & InterpreterSelectQuery::selectQuery() -{ - return typeid_cast(*query_ptr); -} - - void InterpreterSelectQuery::getDatabaseAndTableNames(String & database_name, String & table_name) { - if (auto db_and_table = getDatabaseAndTable(selectQuery(), 0)) + if (auto db_and_table = getDatabaseAndTable(getSelectQuery(), 0)) { table_name = db_and_table->table; database_name = db_and_table->database; @@ -384,7 +377,7 @@ InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpression { ExpressionActionsChain chain(context); - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); Names additional_required_columns_after_prewhere; @@ -508,7 +501,8 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt * then perform the remaining operations with one resulting stream. */ - ASTSelectQuery & query = selectQuery(); + /// Now we will compose block streams that perform the necessary actions. + auto & query = getSelectQuery(); const Settings & settings = context.getSettingsRef(); QueryProcessingStage::Enum from_stage = QueryProcessingStage::FetchColumns; @@ -570,8 +564,6 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt if (to_stage > QueryProcessingStage::FetchColumns) { - /// Now we will compose block streams that perform the necessary actions. - /// Do I need to aggregate in a separate row rows that have not passed max_rows_to_group_by. bool aggregate_overflow_row = expressions.need_aggregate && @@ -590,7 +582,7 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt { if (expressions.hasJoin()) { - const ASTTableJoin & join = static_cast(*query.join()->table_join); + const auto & join = query.join()->table_join->as(); if (isRightOrFull(join.kind)) pipeline.stream_with_non_joined_data = expressions.before_join->createStreamWithNonJoinedDataIfFullOrRightJoin( pipeline.firstStream()->getHeader(), settings.max_block_size); @@ -786,7 +778,7 @@ static std::pair getLimitLengthAndOffset(const ASTSelectQuery & return {length, offset}; } -static UInt64 getLimitForSorting(ASTSelectQuery & query, const Context & context) +static UInt64 getLimitForSorting(const ASTSelectQuery & query, const Context & context) { /// Partial sort can be done if there is LIMIT but no DISTINCT or LIMIT BY. if (!query.distinct && !query.limit_by_expression_list) @@ -802,7 +794,7 @@ void InterpreterSelectQuery::executeFetchColumns( QueryProcessingStage::Enum processing_stage, Pipeline & pipeline, const PrewhereInfoPtr & prewhere_info, const Names & columns_to_remove_after_prewhere) { - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); const Settings & settings = context.getSettingsRef(); /// Actions to calculate ALIAS if required. @@ -1097,7 +1089,7 @@ void InterpreterSelectQuery::executeWhere(Pipeline & pipeline, const ExpressionA { pipeline.transform([&](auto & stream) { - stream = std::make_shared(stream, expression, selectQuery().where_expression->getColumnName(), remove_fiter); + stream = std::make_shared(stream, expression, getSelectQuery().where_expression->getColumnName(), remove_fiter); }); } @@ -1225,7 +1217,7 @@ void InterpreterSelectQuery::executeHaving(Pipeline & pipeline, const Expression { pipeline.transform([&](auto & stream) { - stream = std::make_shared(stream, expression, selectQuery().having_expression->getColumnName()); + stream = std::make_shared(stream, expression, getSelectQuery().having_expression->getColumnName()); }); } @@ -1237,8 +1229,13 @@ void InterpreterSelectQuery::executeTotalsAndHaving(Pipeline & pipeline, bool ha const Settings & settings = context.getSettingsRef(); pipeline.firstStream() = std::make_shared( - pipeline.firstStream(), overflow_row, expression, - has_having ? selectQuery().having_expression->getColumnName() : "", settings.totals_mode, settings.totals_auto_threshold, final); + pipeline.firstStream(), + overflow_row, + expression, + has_having ? getSelectQuery().having_expression->getColumnName() : "", + settings.totals_mode, + settings.totals_auto_threshold, + final); } void InterpreterSelectQuery::executeRollupOrCube(Pipeline & pipeline, Modificator modificator) @@ -1281,18 +1278,18 @@ void InterpreterSelectQuery::executeExpression(Pipeline & pipeline, const Expres } -static SortDescription getSortDescription(ASTSelectQuery & query) +static SortDescription getSortDescription(const ASTSelectQuery & query) { SortDescription order_descr; order_descr.reserve(query.order_expression_list->children.size()); for (const auto & elem : query.order_expression_list->children) { String name = elem->children.front()->getColumnName(); - const ASTOrderByElement & order_by_elem = typeid_cast(*elem); + const auto & order_by_elem = elem->as(); std::shared_ptr collator; if (order_by_elem.collation) - collator = std::make_shared(typeid_cast(*order_by_elem.collation).value.get()); + collator = std::make_shared(order_by_elem.collation->as().value.get()); order_descr.emplace_back(name, order_by_elem.direction, order_by_elem.nulls_direction, collator); } @@ -1303,7 +1300,7 @@ static SortDescription getSortDescription(ASTSelectQuery & query) void InterpreterSelectQuery::executeOrder(Pipeline & pipeline) { - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); SortDescription order_descr = getSortDescription(query); UInt64 limit = getLimitForSorting(query, context); @@ -1335,7 +1332,7 @@ void InterpreterSelectQuery::executeOrder(Pipeline & pipeline) void InterpreterSelectQuery::executeMergeSorted(Pipeline & pipeline) { - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); SortDescription order_descr = getSortDescription(query); UInt64 limit = getLimitForSorting(query, context); @@ -1372,7 +1369,7 @@ void InterpreterSelectQuery::executeProjection(Pipeline & pipeline, const Expres void InterpreterSelectQuery::executeDistinct(Pipeline & pipeline, bool before_order, Names columns) { - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); if (query.distinct) { const Settings & settings = context.getSettingsRef(); @@ -1415,7 +1412,7 @@ void InterpreterSelectQuery::executeUnion(Pipeline & pipeline) /// Preliminary LIMIT - is used in every source, if there are several sources, before they are combined. void InterpreterSelectQuery::executePreLimit(Pipeline & pipeline) { - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); /// If there is LIMIT if (query.limit_length) { @@ -1430,7 +1427,7 @@ void InterpreterSelectQuery::executePreLimit(Pipeline & pipeline) void InterpreterSelectQuery::executeLimitBy(Pipeline & pipeline) { - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); if (!query.limit_by_value || !query.limit_by_expression_list) return; @@ -1458,10 +1455,10 @@ bool hasWithTotalsInAnySubqueryInFromClause(const ASTSelectQuery & query) if (auto query_table = extractTableExpression(query, 0)) { - if (auto ast_union = typeid_cast(query_table.get())) + if (const auto * ast_union = query_table->as()) { for (const auto & elem : ast_union->list_of_selects->children) - if (hasWithTotalsInAnySubqueryInFromClause(typeid_cast(*elem))) + if (hasWithTotalsInAnySubqueryInFromClause(elem->as())) return true; } } @@ -1472,7 +1469,7 @@ bool hasWithTotalsInAnySubqueryInFromClause(const ASTSelectQuery & query) void InterpreterSelectQuery::executeLimit(Pipeline & pipeline) { - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); /// If there is LIMIT if (query.limit_length) { @@ -1544,13 +1541,13 @@ void InterpreterSelectQuery::unifyStreams(Pipeline & pipeline) void InterpreterSelectQuery::ignoreWithTotals() { - selectQuery().group_by_with_totals = false; + getSelectQuery().group_by_with_totals = false; } void InterpreterSelectQuery::initSettings() { - ASTSelectQuery & query = selectQuery(); + auto & query = getSelectQuery(); if (query.settings) InterpreterSetQuery(query.settings, context).executeForCurrentContext(); } diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.h b/dbms/src/Interpreters/InterpreterSelectQuery.h index 89fdc35eb7b..fa4651c12ff 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.h +++ b/dbms/src/Interpreters/InterpreterSelectQuery.h @@ -3,11 +3,12 @@ #include #include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include #include @@ -16,7 +17,6 @@ namespace Poco { class Logger; } namespace DB { -class ASTSelectQuery; struct SubqueryForSet; class InterpreterSelectWithUnionQuery; @@ -99,6 +99,8 @@ private: bool only_analyze_, bool modify_inplace); + ASTSelectQuery & getSelectQuery() { return query_ptr->as(); } + struct Pipeline { @@ -133,7 +135,6 @@ private: } }; - ASTSelectQuery & selectQuery(); void executeImpl(Pipeline & pipeline, const BlockInputStreamPtr & prepared_input, bool dry_run); diff --git a/dbms/src/Interpreters/InterpreterSelectWithUnionQuery.cpp b/dbms/src/Interpreters/InterpreterSelectWithUnionQuery.cpp index 1dc5419223e..34918023d15 100644 --- a/dbms/src/Interpreters/InterpreterSelectWithUnionQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectWithUnionQuery.cpp @@ -36,7 +36,7 @@ InterpreterSelectWithUnionQuery::InterpreterSelectWithUnionQuery( to_stage(to_stage_), subquery_depth(subquery_depth_) { - const ASTSelectWithUnionQuery & ast = typeid_cast(*query_ptr); + const auto & ast = query_ptr->as(); size_t num_selects = ast.list_of_selects->children.size(); diff --git a/dbms/src/Interpreters/InterpreterSetQuery.cpp b/dbms/src/Interpreters/InterpreterSetQuery.cpp index bd69ff2ce56..08b5bc5a620 100644 --- a/dbms/src/Interpreters/InterpreterSetQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSetQuery.cpp @@ -16,7 +16,7 @@ namespace ErrorCodes BlockIO InterpreterSetQuery::execute() { - const ASTSetQuery & ast = typeid_cast(*query_ptr); + const auto & ast = query_ptr->as(); checkAccess(ast); @@ -61,7 +61,7 @@ void InterpreterSetQuery::checkAccess(const ASTSetQuery & ast) void InterpreterSetQuery::executeForCurrentContext() { - const ASTSetQuery & ast = typeid_cast(*query_ptr); + const auto & ast = query_ptr->as(); checkAccess(ast); diff --git a/dbms/src/Interpreters/InterpreterSetQuery.h b/dbms/src/Interpreters/InterpreterSetQuery.h index 7dbd444f6f1..434765cf7d0 100644 --- a/dbms/src/Interpreters/InterpreterSetQuery.h +++ b/dbms/src/Interpreters/InterpreterSetQuery.h @@ -1,15 +1,14 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; class ASTSetQuery; -using ASTPtr = std::shared_ptr; /** Change one or several settings for the session or just for the current context. diff --git a/dbms/src/Interpreters/InterpreterShowCreateQuery.cpp b/dbms/src/Interpreters/InterpreterShowCreateQuery.cpp index 852bf45d720..74299ffaf4a 100644 --- a/dbms/src/Interpreters/InterpreterShowCreateQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowCreateQuery.cpp @@ -42,6 +42,7 @@ Block InterpreterShowCreateQuery::getSampleBlock() BlockInputStreamPtr InterpreterShowCreateQuery::executeImpl() { + /// FIXME: try to prettify this cast using `as<>()` const auto & ast = dynamic_cast(*query_ptr); if (ast.temporary && !ast.database.empty()) diff --git a/dbms/src/Interpreters/InterpreterShowCreateQuery.h b/dbms/src/Interpreters/InterpreterShowCreateQuery.h index 5e8672c1767..5ac98509b23 100644 --- a/dbms/src/Interpreters/InterpreterShowCreateQuery.h +++ b/dbms/src/Interpreters/InterpreterShowCreateQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; /** Return single row with single column "statement" of type String with text of query to CREATE specified table. diff --git a/dbms/src/Interpreters/InterpreterShowProcesslistQuery.h b/dbms/src/Interpreters/InterpreterShowProcesslistQuery.h index ee180e76699..6b87fd7edc3 100644 --- a/dbms/src/Interpreters/InterpreterShowProcesslistQuery.h +++ b/dbms/src/Interpreters/InterpreterShowProcesslistQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; /** Return list of currently executing queries. diff --git a/dbms/src/Interpreters/InterpreterShowTablesQuery.cpp b/dbms/src/Interpreters/InterpreterShowTablesQuery.cpp index ab15d1f0112..774edcc3390 100644 --- a/dbms/src/Interpreters/InterpreterShowTablesQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowTablesQuery.cpp @@ -25,7 +25,7 @@ InterpreterShowTablesQuery::InterpreterShowTablesQuery(const ASTPtr & query_ptr_ String InterpreterShowTablesQuery::getRewrittenQuery() { - const ASTShowTablesQuery & query = typeid_cast(*query_ptr); + const auto & query = query_ptr->as(); /// SHOW DATABASES if (query.databases) diff --git a/dbms/src/Interpreters/InterpreterShowTablesQuery.h b/dbms/src/Interpreters/InterpreterShowTablesQuery.h index 3f661da3e2d..fc5cb2b7505 100644 --- a/dbms/src/Interpreters/InterpreterShowTablesQuery.h +++ b/dbms/src/Interpreters/InterpreterShowTablesQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; /** Return a list of tables or databases meets specified conditions. diff --git a/dbms/src/Interpreters/InterpreterSystemQuery.cpp b/dbms/src/Interpreters/InterpreterSystemQuery.cpp index 20bd860fb26..6bb0b3474fc 100644 --- a/dbms/src/Interpreters/InterpreterSystemQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSystemQuery.cpp @@ -117,7 +117,7 @@ InterpreterSystemQuery::InterpreterSystemQuery(const ASTPtr & query_ptr_, Contex BlockIO InterpreterSystemQuery::execute() { - auto & query = typeid_cast(*query_ptr); + auto & query = query_ptr->as(); using Type = ASTSystemQuery::Type; @@ -248,7 +248,7 @@ StoragePtr InterpreterSystemQuery::tryRestartReplica(const String & database_nam /// Attach actions { /// getCreateTableQuery must return canonical CREATE query representation, there are no need for AST postprocessing - auto & create = typeid_cast(*create_ast); + auto & create = create_ast->as(); create.attach = true; std::string data_path = database->getDataPath(); diff --git a/dbms/src/Interpreters/InterpreterSystemQuery.h b/dbms/src/Interpreters/InterpreterSystemQuery.h index fb92e799761..65aaf789419 100644 --- a/dbms/src/Interpreters/InterpreterSystemQuery.h +++ b/dbms/src/Interpreters/InterpreterSystemQuery.h @@ -1,15 +1,15 @@ #pragma once + #include +#include namespace DB { class Context; -class IAST; class ASTSystemQuery; class IStorage; -using ASTPtr = std::shared_ptr; using StoragePtr = std::shared_ptr; diff --git a/dbms/src/Interpreters/InterpreterUseQuery.cpp b/dbms/src/Interpreters/InterpreterUseQuery.cpp index 8dba0d55223..d815d66aadc 100644 --- a/dbms/src/Interpreters/InterpreterUseQuery.cpp +++ b/dbms/src/Interpreters/InterpreterUseQuery.cpp @@ -9,7 +9,7 @@ namespace DB BlockIO InterpreterUseQuery::execute() { - const String & new_database = typeid_cast(*query_ptr).database; + const String & new_database = query_ptr->as().database; context.getSessionContext().setCurrentDatabase(new_database); return {}; } diff --git a/dbms/src/Interpreters/InterpreterUseQuery.h b/dbms/src/Interpreters/InterpreterUseQuery.h index 988ccb741fa..ae409117afd 100644 --- a/dbms/src/Interpreters/InterpreterUseQuery.h +++ b/dbms/src/Interpreters/InterpreterUseQuery.h @@ -1,14 +1,13 @@ #pragma once #include +#include namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; /** Change default database for session. diff --git a/dbms/src/Interpreters/JoinToSubqueryTransformVisitor.cpp b/dbms/src/Interpreters/JoinToSubqueryTransformVisitor.cpp index 1a110a6c8e0..a99e79b1120 100644 --- a/dbms/src/Interpreters/JoinToSubqueryTransformVisitor.cpp +++ b/dbms/src/Interpreters/JoinToSubqueryTransformVisitor.cpp @@ -90,18 +90,17 @@ struct ColumnAliasesMatcher static bool needChildVisit(ASTPtr & node, const ASTPtr &) { - if (typeid_cast(node.get())) + if (node->as()) return false; return true; } static void visit(ASTPtr & ast, Data & data) { - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) visit(*t, ast, data); - if (typeid_cast(ast.get()) || - typeid_cast(ast.get())) + if (ast->as() || ast->as()) throw Exception("Multiple JOIN do not support asterisks yet", ErrorCodes::NOT_IMPLEMENTED); } @@ -160,9 +159,9 @@ struct AppendSemanticVisitorData for (auto & child : select.select_expression_list->children) { - if (auto * node = typeid_cast(child.get())) + if (auto * node = child->as()) AsteriskSemantic::setAliases(*node, rev_aliases); - if (auto * node = typeid_cast(child.get())) + if (auto * node = child->as()) AsteriskSemantic::setAliases(*node, rev_aliases); } @@ -196,7 +195,7 @@ bool needRewrite(ASTSelectQuery & select) if (!select.tables) return false; - auto tables = typeid_cast(select.tables.get()); + const auto * tables = select.tables->as(); if (!tables) return false; @@ -206,11 +205,11 @@ bool needRewrite(ASTSelectQuery & select) for (size_t i = 1; i < tables->children.size(); ++i) { - auto table = typeid_cast(tables->children[i].get()); + const auto * table = tables->children[i]->as(); if (!table || !table->table_join) throw Exception("Multiple JOIN expects joined tables", ErrorCodes::LOGICAL_ERROR); - auto join = typeid_cast(*table->table_join); + const auto & join = table->table_join->as(); if (isComma(join.kind)) throw Exception("COMMA to CROSS JOIN rewriter is not enabled or cannot rewrite query", ErrorCodes::NOT_IMPLEMENTED); @@ -233,7 +232,7 @@ using AppendSemanticVisitor = InDepthNodeVisitor; void JoinToSubqueryTransformMatcher::visit(ASTPtr & ast, Data & data) { - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) visit(*t, ast, data); } @@ -261,10 +260,10 @@ void JoinToSubqueryTransformMatcher::visit(ASTSelectQuery & select, ASTPtr &, Da /// JOIN sections for (auto & child : select.tables->children) { - auto table = typeid_cast(child.get()); + auto * table = child->as(); if (table->table_join) { - auto & join = typeid_cast(*table->table_join); + auto & join = table->table_join->as(); if (join.on_expression) ColumnAliasesVisitor(aliases_data).visit(join.on_expression); } @@ -307,8 +306,8 @@ static ASTPtr makeSubqueryTemplate() ASTPtr JoinToSubqueryTransformMatcher::replaceJoin(ASTPtr ast_left, ASTPtr ast_right) { - auto left = typeid_cast(ast_left.get()); - auto right = typeid_cast(ast_right.get()); + const auto * left = ast_left->as(); + const auto * right = ast_right->as(); if (!left || !right) throw Exception("Two TablesInSelectQueryElements expected", ErrorCodes::LOGICAL_ERROR); diff --git a/dbms/src/Interpreters/LogicalExpressionsOptimizer.cpp b/dbms/src/Interpreters/LogicalExpressionsOptimizer.cpp index d7f04ff2a25..b0ddd669d62 100644 --- a/dbms/src/Interpreters/LogicalExpressionsOptimizer.cpp +++ b/dbms/src/Interpreters/LogicalExpressionsOptimizer.cpp @@ -19,7 +19,7 @@ namespace ErrorCodes } -LogicalExpressionsOptimizer::OrWithExpression::OrWithExpression(ASTFunction * or_function_, +LogicalExpressionsOptimizer::OrWithExpression::OrWithExpression(const ASTFunction * or_function_, const IAST::Hash & expression_, const std::string & alias_) : or_function(or_function_), expression(expression_), alias(alias_) { @@ -111,24 +111,24 @@ void LogicalExpressionsOptimizer::collectDisjunctiveEqualityChains() bool found_chain = false; - auto function = typeid_cast(to_node); - if ((function != nullptr) && (function->name == "or") && (function->children.size() == 1)) + auto * function = to_node->as(); + if (function && function->name == "or" && function->children.size() == 1) { - auto expression_list = typeid_cast(&*(function->children[0])); - if (expression_list != nullptr) + const auto * expression_list = function->children[0]->as(); + if (expression_list) { /// The chain of elements of the OR expression. for (auto & child : expression_list->children) { - auto equals = typeid_cast(&*child); - if ((equals != nullptr) && (equals->name == "equals") && (equals->children.size() == 1)) + auto * equals = child->as(); + if (equals && equals->name == "equals" && equals->children.size() == 1) { - auto equals_expression_list = typeid_cast(&*(equals->children[0])); - if ((equals_expression_list != nullptr) && (equals_expression_list->children.size() == 2)) + const auto * equals_expression_list = equals->children[0]->as(); + if (equals_expression_list && equals_expression_list->children.size() == 2) { /// Equality expr = xN. - auto literal = typeid_cast(&*(equals_expression_list->children[1])); - if (literal != nullptr) + const auto * literal = equals_expression_list->children[1]->as(); + if (literal) { auto expr_lhs = equals_expression_list->children[0]->getTreeHash(); OrWithExpression or_with_expression{function, expr_lhs, function->tryGetAlias()}; @@ -157,7 +157,7 @@ void LogicalExpressionsOptimizer::collectDisjunctiveEqualityChains() { for (auto & child : to_node->children) { - if (typeid_cast(child.get()) == nullptr) + if (!child->as()) { if (!visited_nodes.count(child.get())) to_visit.push_back(Edge(to_node, &*child)); @@ -187,10 +187,9 @@ void LogicalExpressionsOptimizer::collectDisjunctiveEqualityChains() namespace { -inline ASTs & getFunctionOperands(ASTFunction * or_function) +inline ASTs & getFunctionOperands(const ASTFunction * or_function) { - auto expression_list = static_cast(&*(or_function->children[0])); - return expression_list->children; + return or_function->children[0]->children; } } @@ -206,11 +205,11 @@ bool LogicalExpressionsOptimizer::mayOptimizeDisjunctiveEqualityChain(const Disj /// We check that the right-hand sides of all equalities have the same type. auto & first_operands = getFunctionOperands(equality_functions[0]); - auto first_literal = static_cast(&*first_operands[1]); + const auto * first_literal = first_operands[1]->as(); for (size_t i = 1; i < equality_functions.size(); ++i) { auto & operands = getFunctionOperands(equality_functions[i]); - auto literal = static_cast(&*operands[1]); + const auto * literal = operands[1]->as(); if (literal->value.getType() != first_literal->value.getType()) return false; @@ -238,8 +237,8 @@ void LogicalExpressionsOptimizer::addInExpression(const DisjunctiveEqualityChain /// Otherwise, they would be specified in the order of the ASTLiteral addresses, which is nondeterministic. std::sort(value_list->children.begin(), value_list->children.end(), [](const DB::ASTPtr & lhs, const DB::ASTPtr & rhs) { - const auto val_lhs = static_cast(&*lhs); - const auto val_rhs = static_cast(&*rhs); + const auto * val_lhs = lhs->as(); + const auto * val_rhs = rhs->as(); return val_lhs->value < val_rhs->value; }); @@ -277,7 +276,7 @@ void LogicalExpressionsOptimizer::cleanupOrExpressions() { /// Saves for each optimized OR-chain the iterator on the first element /// list of operands to be deleted. - std::unordered_map garbage_map; + std::unordered_map garbage_map; /// Initialization. garbage_map.reserve(processed_count); diff --git a/dbms/src/Interpreters/LogicalExpressionsOptimizer.h b/dbms/src/Interpreters/LogicalExpressionsOptimizer.h index 636c83e1d9f..09c3931ce1d 100644 --- a/dbms/src/Interpreters/LogicalExpressionsOptimizer.h +++ b/dbms/src/Interpreters/LogicalExpressionsOptimizer.h @@ -51,11 +51,10 @@ private: */ struct OrWithExpression { - OrWithExpression(ASTFunction * or_function_, const IAST::Hash & expression_, - const std::string & alias_); + OrWithExpression(const ASTFunction * or_function_, const IAST::Hash & expression_, const std::string & alias_); bool operator<(const OrWithExpression & rhs) const; - ASTFunction * or_function; + const ASTFunction * or_function; const IAST::Hash expression; const std::string alias; }; @@ -95,8 +94,8 @@ private: private: using ParentNodes = std::vector; - using FunctionParentMap = std::unordered_map; - using ColumnToPosition = std::unordered_map; + using FunctionParentMap = std::unordered_map; + using ColumnToPosition = std::unordered_map; private: ASTSelectQuery * select_query; diff --git a/dbms/src/Interpreters/OptimizeIfWithConstantConditionVisitor.cpp b/dbms/src/Interpreters/OptimizeIfWithConstantConditionVisitor.cpp index e73a734ab16..dd63093493f 100644 --- a/dbms/src/Interpreters/OptimizeIfWithConstantConditionVisitor.cpp +++ b/dbms/src/Interpreters/OptimizeIfWithConstantConditionVisitor.cpp @@ -16,7 +16,7 @@ namespace ErrorCodes static bool tryExtractConstValueFromCondition(const ASTPtr & condition, bool & value) { /// numeric constant in condition - if (const ASTLiteral * literal = typeid_cast(condition.get())) + if (const auto * literal = condition->as()) { if (literal->value.getType() == Field::Types::Int64 || literal->value.getType() == Field::Types::UInt64) @@ -27,14 +27,14 @@ static bool tryExtractConstValueFromCondition(const ASTPtr & condition, bool & v } /// cast of numeric constant in condition to UInt8 - if (const ASTFunction * function = typeid_cast(condition.get())) + if (const auto * function = condition->as()) { if (function->name == "CAST") { - if (ASTExpressionList * expr_list = typeid_cast(function->arguments.get())) + if (const auto * expr_list = function->arguments->as()) { const ASTPtr & type_ast = expr_list->children.at(1); - if (const ASTLiteral * type_literal = typeid_cast(type_ast.get())) + if (const auto * type_literal = type_ast->as()) { if (type_literal->value.getType() == Field::Types::String && type_literal->value.get() == "UInt8") @@ -54,7 +54,7 @@ void OptimizeIfWithConstantConditionVisitor::visit(ASTPtr & current_ast) for (ASTPtr & child : current_ast->children) { - auto * function_node = typeid_cast(child.get()); + auto * function_node = child->as(); if (!function_node || function_node->name != "if") { visit(child); @@ -62,7 +62,7 @@ void OptimizeIfWithConstantConditionVisitor::visit(ASTPtr & current_ast) } visit(function_node->arguments); - auto * args = typeid_cast(function_node->arguments.get()); + const auto * args = function_node->arguments->as(); if (args->children.size() != 3) throw Exception("Wrong number of arguments for function 'if' (" + toString(args->children.size()) + " instead of 3)", diff --git a/dbms/src/Interpreters/PredicateExpressionsOptimizer.cpp b/dbms/src/Interpreters/PredicateExpressionsOptimizer.cpp index 612ea231bdd..b564c2cd52d 100644 --- a/dbms/src/Interpreters/PredicateExpressionsOptimizer.cpp +++ b/dbms/src/Interpreters/PredicateExpressionsOptimizer.cpp @@ -151,7 +151,7 @@ std::vector PredicateExpressionsOptimizer::splitConjunctionPredicate(AST { const auto expression = predicate_expressions.at(idx); - if (const auto function = typeid_cast(expression.get())) + if (const auto * function = expression->as()) { if (function->name == and_function_name) { @@ -239,7 +239,7 @@ void PredicateExpressionsOptimizer::setNewAliasesForInnerPredicate( if (alias == qualified_name) { String name; - if (auto * id = typeid_cast(ast.get())) + if (auto * id = ast->as()) { name = id->tryGetAlias(); if (name.empty()) @@ -260,7 +260,7 @@ void PredicateExpressionsOptimizer::setNewAliasesForInnerPredicate( bool PredicateExpressionsOptimizer::isArrayJoinFunction(const ASTPtr & node) { - if (auto function = typeid_cast(node.get())) + if (const auto * function = node->as()) { if (function->name == "arrayJoin") return true; @@ -309,7 +309,7 @@ void PredicateExpressionsOptimizer::getSubqueryProjectionColumns(const ASTPtr & const ASTPtr & subselect = subquery->children[0]; ASTs select_with_union_projections; - auto select_with_union_query = static_cast(subselect.get()); + const auto * select_with_union_query = subselect->as(); for (auto & select : select_with_union_query->list_of_selects->children) { @@ -325,7 +325,7 @@ void PredicateExpressionsOptimizer::getSubqueryProjectionColumns(const ASTPtr & subquery_projections.emplace_back(std::pair(select_projection_columns[i], qualified_name_prefix + select_with_union_projections[i]->getAliasOrColumnName())); - projection_columns.insert(std::pair(static_cast(select.get()), subquery_projections)); + projection_columns.insert(std::pair(select->as(), subquery_projections)); } } } @@ -333,7 +333,7 @@ void PredicateExpressionsOptimizer::getSubqueryProjectionColumns(const ASTPtr & ASTs PredicateExpressionsOptimizer::getSelectQueryProjectionColumns(ASTPtr & ast) { ASTs projection_columns; - auto select_query = static_cast(ast.get()); + auto * select_query = ast->as(); /// first should normalize query tree. std::unordered_map aliases; @@ -352,7 +352,7 @@ ASTs PredicateExpressionsOptimizer::getSelectQueryProjectionColumns(ASTPtr & ast for (const auto & projection_column : select_query->select_expression_list->children) { - if (typeid_cast(projection_column.get()) || typeid_cast(projection_column.get())) + if (projection_column->as() || projection_column->as()) { ASTs evaluated_columns = evaluateAsterisk(select_query, projection_column); @@ -375,7 +375,7 @@ ASTs PredicateExpressionsOptimizer::evaluateAsterisk(ASTSelectQuery * select_que std::vector tables_expression = getSelectTablesExpression(*select_query); - if (const auto qualified_asterisk = typeid_cast(asterisk.get())) + if (const auto * qualified_asterisk = asterisk->as()) { if (qualified_asterisk->children.size() != 1) throw Exception("Logical error: qualified asterisk must have exactly one child", ErrorCodes::LOGICAL_ERROR); @@ -399,8 +399,8 @@ ASTs PredicateExpressionsOptimizer::evaluateAsterisk(ASTSelectQuery * select_que { if (table_expression->subquery) { - const auto subquery = static_cast(table_expression->subquery.get()); - const auto select_with_union_query = static_cast(subquery->children[0].get()); + const auto * subquery = table_expression->subquery->as(); + const auto * select_with_union_query = subquery->children[0]->as(); const auto subquery_projections = getSelectQueryProjectionColumns(select_with_union_query->list_of_selects->children[0]); projection_columns.insert(projection_columns.end(), subquery_projections.begin(), subquery_projections.end()); } @@ -415,7 +415,7 @@ ASTs PredicateExpressionsOptimizer::evaluateAsterisk(ASTSelectQuery * select_que } else if (table_expression->database_and_table_name) { - const auto database_and_table_ast = static_cast(table_expression->database_and_table_name.get()); + const auto * database_and_table_ast = table_expression->database_and_table_name->as(); DatabaseAndTableWithAlias database_and_table_name(*database_and_table_ast); storage = context.getTable(database_and_table_name.database, database_and_table_name.table); } diff --git a/dbms/src/Interpreters/ProcessList.cpp b/dbms/src/Interpreters/ProcessList.cpp index f1cdc946771..007d77c649e 100644 --- a/dbms/src/Interpreters/ProcessList.cpp +++ b/dbms/src/Interpreters/ProcessList.cpp @@ -38,7 +38,7 @@ static bool isUnlimitedQuery(const IAST * ast) return false; /// It is KILL QUERY - if (typeid_cast(ast)) + if (ast->as()) return true; /// It is SELECT FROM system.processes @@ -46,12 +46,12 @@ static bool isUnlimitedQuery(const IAST * ast) /// False negative: USE system; SELECT * FROM processes; /// False positive: SELECT * FROM system.processes CROSS JOIN (SELECT ...) - if (auto ast_selects = typeid_cast(ast)) + if (const auto * ast_selects = ast->as()) { if (!ast_selects->list_of_selects || ast_selects->list_of_selects->children.empty()) return false; - auto ast_select = typeid_cast(ast_selects->list_of_selects->children[0].get()); + const auto * ast_select = ast_selects->list_of_selects->children[0]->as(); if (!ast_select) return false; diff --git a/dbms/src/Interpreters/QueryAliasesVisitor.cpp b/dbms/src/Interpreters/QueryAliasesVisitor.cpp index 96ef4806676..f9257870583 100644 --- a/dbms/src/Interpreters/QueryAliasesVisitor.cpp +++ b/dbms/src/Interpreters/QueryAliasesVisitor.cpp @@ -32,18 +32,16 @@ static String wrongAliasMessage(const ASTPtr & ast, const ASTPtr & prev_ast, con bool QueryAliasesMatcher::needChildVisit(ASTPtr & node, const ASTPtr &) { /// Don't descent into table functions and subqueries and special case for ArrayJoin. - if (typeid_cast(node.get()) || - typeid_cast(node.get()) || - typeid_cast(node.get())) + if (node->as() || node->as() || node->as()) return false; return true; } void QueryAliasesMatcher::visit(ASTPtr & ast, Data & data) { - if (auto * s = typeid_cast(ast.get())) + if (auto * s = ast->as()) visit(*s, ast, data); - else if (auto * aj = typeid_cast(ast.get())) + else if (auto * aj = ast->as()) visit(*aj, ast, data); else visitOther(ast, data); diff --git a/dbms/src/Interpreters/QueryNormalizer.cpp b/dbms/src/Interpreters/QueryNormalizer.cpp index f689eff6555..f4f66f59dce 100644 --- a/dbms/src/Interpreters/QueryNormalizer.cpp +++ b/dbms/src/Interpreters/QueryNormalizer.cpp @@ -134,14 +134,14 @@ void QueryNormalizer::visit(ASTTablesInSelectQueryElement & node, const ASTPtr & /// mark table Identifiers as 'not a column' if (node.table_expression) { - auto & expr = static_cast(*node.table_expression); + auto & expr = node.table_expression->as(); setIdentifierSpecial(expr.database_and_table_name); } /// normalize JOIN ON section if (node.table_join) { - auto & join = static_cast(*node.table_join); + auto & join = node.table_join->as(); if (join.on_expression) visit(join.on_expression, data); } @@ -149,8 +149,7 @@ void QueryNormalizer::visit(ASTTablesInSelectQueryElement & node, const ASTPtr & static bool needVisitChild(const ASTPtr & child) { - if (typeid_cast(child.get()) || - typeid_cast(child.get())) + if (child->as() || child->as()) return false; return true; } @@ -178,7 +177,7 @@ void QueryNormalizer::visit(ASTSelectQuery & select, const ASTPtr & ast, Data & /// on aliases in expressions of the form 123 AS x, arrayMap(x -> 1, [2]). void QueryNormalizer::visitChildren(const ASTPtr & node, Data & data) { - if (ASTFunction * func_node = typeid_cast(node.get())) + if (const auto * func_node = node->as()) { /// We skip the first argument. We also assume that the lambda function can not have parameters. size_t first_pos = 0; @@ -195,7 +194,7 @@ void QueryNormalizer::visitChildren(const ASTPtr & node, Data & data) visit(child, data); } } - else if (!typeid_cast(node.get())) + else if (!node->as()) { for (auto & child : node->children) if (needVisitChild(child)) @@ -226,13 +225,13 @@ void QueryNormalizer::visit(ASTPtr & ast, Data & data) data.current_alias = my_alias; } - if (auto * node = typeid_cast(ast.get())) + if (auto * node = ast->as()) visit(*node, ast, data); - if (auto * node = typeid_cast(ast.get())) + if (auto * node = ast->as()) visit(*node, ast, data); - if (auto * node = typeid_cast(ast.get())) + if (auto * node = ast->as()) visit(*node, ast, data); - if (auto * node = typeid_cast(ast.get())) + if (auto * node = ast->as()) visit(*node, ast, data); /// If we replace the root of the subtree, we will be called again for the new root, in case the alias is replaced by an alias. diff --git a/dbms/src/Interpreters/RequiredSourceColumnsVisitor.cpp b/dbms/src/Interpreters/RequiredSourceColumnsVisitor.cpp index 8c24c916f22..f05cb6faa41 100644 --- a/dbms/src/Interpreters/RequiredSourceColumnsVisitor.cpp +++ b/dbms/src/Interpreters/RequiredSourceColumnsVisitor.cpp @@ -22,7 +22,7 @@ static std::vector extractNamesFromLambda(const ASTFunction & node) if (node.arguments->children.size() != 2) throw Exception("lambda requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - ASTFunction * lambda_args_tuple = typeid_cast(node.arguments->children[0].get()); + const auto * lambda_args_tuple = node.arguments->children[0]->as(); if (!lambda_args_tuple || lambda_args_tuple->name != "tuple") throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH); @@ -30,7 +30,7 @@ static std::vector extractNamesFromLambda(const ASTFunction & node) std::vector names; for (auto & child : lambda_args_tuple->arguments->children) { - ASTIdentifier * identifier = typeid_cast(child.get()); + const auto * identifier = child->as(); if (!identifier) throw Exception("lambda argument declarations must be identifiers", ErrorCodes::TYPE_MISMATCH); @@ -42,16 +42,14 @@ static std::vector extractNamesFromLambda(const ASTFunction & node) bool RequiredSourceColumnsMatcher::needChildVisit(ASTPtr & node, const ASTPtr & child) { - if (typeid_cast(child.get())) + if (child->as()) return false; /// Processed. Do not need children. - if (typeid_cast(node.get()) || - typeid_cast(node.get()) || - typeid_cast(node.get())) + if (node->as() || node->as() || node->as()) return false; - if (auto * f = typeid_cast(node.get())) + if (const auto * f = node->as()) { /// "indexHint" is a special function for index analysis. Everything that is inside it is not calculated. @sa KeyCondition /// "lambda" visit children itself. @@ -66,12 +64,12 @@ void RequiredSourceColumnsMatcher::visit(ASTPtr & ast, Data & data) { /// results are columns - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) { visit(*t, ast, data); return; } - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) { data.addColumnAliasIfAny(*ast); visit(*t, ast, data); @@ -80,24 +78,24 @@ void RequiredSourceColumnsMatcher::visit(ASTPtr & ast, Data & data) /// results are tables - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) { visit(*t, ast, data); return; } - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) { visit(*t, ast, data); return; } - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) { data.addTableAliasIfAny(*ast); visit(*t, ast, data); return; } - if (typeid_cast(ast.get())) + if (ast->as()) { data.addTableAliasIfAny(*ast); return; @@ -105,7 +103,7 @@ void RequiredSourceColumnsMatcher::visit(ASTPtr & ast, Data & data) /// other - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) { data.has_array_join = true; visit(*t, ast, data); @@ -118,7 +116,7 @@ void RequiredSourceColumnsMatcher::visit(ASTSelectQuery & select, const ASTPtr & /// special case for top-level SELECT items: they are publics for (auto & node : select.select_expression_list->children) { - if (auto * identifier = typeid_cast(node.get())) + if (const auto * identifier = node->as()) data.addColumnIdentifier(*identifier); else data.addColumnAliasIfAny(*node); @@ -170,9 +168,9 @@ void RequiredSourceColumnsMatcher::visit(ASTTablesInSelectQueryElement & node, c for (auto & child : node.children) { - if (auto * e = typeid_cast(child.get())) + if (auto * e = child->as()) expr = e; - if (auto * j = typeid_cast(child.get())) + if (auto * j = child->as()) join = j; } @@ -207,7 +205,7 @@ void RequiredSourceColumnsMatcher::visit(const ASTArrayJoin & node, const ASTPtr { data.addArrayJoinAliasIfAny(*expr); - if (auto * identifier = typeid_cast(expr.get())) + if (const auto * identifier = expr->as()) { data.addArrayJoinIdentifier(*identifier); continue; diff --git a/dbms/src/Interpreters/Set.cpp b/dbms/src/Interpreters/Set.cpp index a10d520d779..c1cf279aee0 100644 --- a/dbms/src/Interpreters/Set.cpp +++ b/dbms/src/Interpreters/Set.cpp @@ -205,13 +205,13 @@ bool Set::insertFromBlock(const Block & block) } -static Field extractValueFromNode(ASTPtr & node, const IDataType & type, const Context & context) +static Field extractValueFromNode(const ASTPtr & node, const IDataType & type, const Context & context) { - if (ASTLiteral * lit = typeid_cast(node.get())) + if (const auto * lit = node->as()) { return convertFieldToType(lit->value, type); } - else if (typeid_cast(node.get())) + else if (node->as()) { std::pair value_raw = evaluateConstantExpression(node, context); return convertFieldToType(value_raw.first, type, value_raw.second.get()); @@ -235,7 +235,7 @@ void Set::createFromAST(const DataTypes & types, ASTPtr node, const Context & co DataTypePtr tuple_type; Row tuple_values; - ASTExpressionList & list = typeid_cast(*node); + const auto & list = node->as(); for (auto & elem : list.children) { if (num_columns == 1) @@ -245,7 +245,7 @@ void Set::createFromAST(const DataTypes & types, ASTPtr node, const Context & co if (!value.isNull()) columns[0]->insert(value); } - else if (ASTFunction * func = typeid_cast(elem.get())) + else if (const auto * func = elem->as()) { Field function_result; const TupleBackend * tuple = nullptr; diff --git a/dbms/src/Interpreters/SyntaxAnalyzer.cpp b/dbms/src/Interpreters/SyntaxAnalyzer.cpp index bc808c3d37f..aab7bd3d4c2 100644 --- a/dbms/src/Interpreters/SyntaxAnalyzer.cpp +++ b/dbms/src/Interpreters/SyntaxAnalyzer.cpp @@ -63,7 +63,7 @@ using LogAST = DebugASTLog; /// set to true to enable logs /// Add columns from storage to source_columns list. -void collectSourceColumns(ASTSelectQuery * select_query, StoragePtr storage, NamesAndTypesList & source_columns) +void collectSourceColumns(const ASTSelectQuery * select_query, StoragePtr storage, NamesAndTypesList & source_columns) { if (storage) { @@ -112,12 +112,12 @@ void translateQualifiedNames(ASTPtr & query, const ASTSelectQuery & select_query bool hasArrayJoin(const ASTPtr & ast) { - if (const ASTFunction * function = typeid_cast(&*ast)) + if (const ASTFunction * function = ast->as()) if (function->name == "arrayJoin") return true; for (const auto & child : ast->children) - if (!typeid_cast(child.get()) && hasArrayJoin(child)) + if (!child->as() && hasArrayJoin(child)) return true; return false; @@ -213,9 +213,9 @@ void optimizeGroupBy(ASTSelectQuery * select_query, const NameSet & source_colum if (!select_query->group_expression_list) return; - const auto is_literal = [] (const ASTPtr & ast) + const auto is_literal = [] (const ASTPtr & ast) -> bool { - return typeid_cast(ast.get()); + return ast->as(); }; auto & group_exprs = select_query->group_expression_list->children; @@ -232,7 +232,7 @@ void optimizeGroupBy(ASTSelectQuery * select_query, const NameSet & source_colum /// iterate over each GROUP BY expression, eliminate injective function calls and literals for (size_t i = 0; i < group_exprs.size();) { - if (const auto function = typeid_cast(group_exprs[i].get())) + if (const auto * function = group_exprs[i]->as()) { /// assert function is injective if (possibly_injective_function_names.count(function->name)) @@ -244,13 +244,9 @@ void optimizeGroupBy(ASTSelectQuery * select_query, const NameSet & source_colum continue; } - const auto & dict_name = typeid_cast(*function->arguments->children[0]) - .value.safeGet(); - + const auto & dict_name = function->arguments->children[0]->as().value.safeGet(); const auto & dict_ptr = context.getExternalDictionaries().getDictionary(dict_name); - - const auto & attr_name = typeid_cast(*function->arguments->children[1]) - .value.safeGet(); + const auto & attr_name = function->arguments->children[1]->as().value.safeGet(); if (!dict_ptr->isInjective(attr_name)) { @@ -328,7 +324,7 @@ void optimizeOrderBy(const ASTSelectQuery * select_query) for (const auto & elem : elems) { String name = elem->children.front()->getColumnName(); - const ASTOrderByElement & order_by_elem = typeid_cast(*elem); + const auto & order_by_elem = elem->as(); if (elems_set.emplace(name, order_by_elem.collation ? order_by_elem.collation->getColumnName() : "").second) unique_elems.emplace_back(elem); @@ -363,11 +359,10 @@ void optimizeLimitBy(const ASTSelectQuery * select_query) /// Remove duplicated columns from USING(...). void optimizeUsing(const ASTSelectQuery * select_query) { - auto node = const_cast(select_query->join()); - if (!node) + if (!select_query->join()) return; - auto table_join = static_cast(&*node->table_join); + const auto * table_join = select_query->join()->table_join->as(); if (!(table_join && table_join->using_expression_list)) return; @@ -410,7 +405,7 @@ void getArrayJoinedColumns(ASTPtr & query, SyntaxAnalyzerResult & result, const String result_name = expr->getAliasOrColumnName(); /// This is an array. - if (!isIdentifier(expr) || source_columns_set.count(source_name)) + if (!expr->as() || source_columns_set.count(source_name)) { result.array_join_result_to_source[result_name] = source_name; } @@ -454,7 +449,7 @@ void collectJoinedColumnsFromJoinOnExpr(AnalyzedJoin & analyzed_join, const ASTT { if (IdentifierSemantic::getColumnName(ast)) { - auto * identifier = typeid_cast(ast.get()); + const auto * identifier = ast->as(); /// It's set in TranslateQualifiedNamesVisitor size_t membership = IdentifierSemantic::getMembership(*identifier); @@ -498,7 +493,7 @@ void collectJoinedColumnsFromJoinOnExpr(AnalyzedJoin & analyzed_join, const ASTT /// For equal expression find out corresponding table for each part, translate qualified names and add asts to join keys. auto add_columns_from_equals_expr = [&](const ASTPtr & expr) { - auto * func_equals = typeid_cast(expr.get()); + const auto * func_equals = expr->as(); if (!func_equals || func_equals->name != "equals") throwSyntaxException("Expected equals expression, got " + queryToString(expr) + "."); @@ -537,7 +532,7 @@ void collectJoinedColumnsFromJoinOnExpr(AnalyzedJoin & analyzed_join, const ASTT } }; - auto * func = typeid_cast(table_join.on_expression.get()); + const auto * func = table_join.on_expression->as(); if (func && func->name == "and") { for (const auto & expr : func->arguments->children) @@ -556,13 +551,13 @@ void collectJoinedColumns(AnalyzedJoin & analyzed_join, const ASTSelectQuery & s if (!node) return; - const auto & table_join = static_cast(*node->table_join); - const auto & table_expression = static_cast(*node->table_expression); + const auto & table_join = node->table_join->as(); + const auto & table_expression = node->table_expression->as(); DatabaseAndTableWithAlias joined_table_name(table_expression, current_database); if (table_join.using_expression_list) { - auto & keys = typeid_cast(*table_join.using_expression_list); + const auto & keys = table_join.using_expression_list->as(); for (const auto & key : keys.children) analyzed_join.addUsingKey(key); @@ -598,10 +593,10 @@ void replaceJoinedTable(const ASTTablesInSelectQueryElement* join) if (!join || !join->table_expression) return; - auto & table_expr = static_cast(*join->table_expression.get()); + auto & table_expr = join->table_expression->as(); if (table_expr.database_and_table_name) { - auto & table_id = typeid_cast(*table_expr.database_and_table_name.get()); + const auto & table_id = table_expr.database_and_table_name->as(); String expr = "(select * from " + table_id.name + ") as " + table_id.shortName(); // FIXME: since the expression "a as b" exposes both "a" and "b" names, which is not equivalent to "(select * from a) as b", @@ -610,7 +605,7 @@ void replaceJoinedTable(const ASTTablesInSelectQueryElement* join) if (table_id.alias.empty() && table_id.isShort()) { ParserTableExpression parser; - table_expr = static_cast(*parseQuery(parser, expr, 0)); + table_expr = parseQuery(parser, expr, 0)->as(); } } } @@ -624,7 +619,7 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze( const Names & required_result_columns, StoragePtr storage) const { - auto * select_query = typeid_cast(query.get()); + auto * select_query = query->as(); if (!storage && select_query) { if (auto db_and_table = getDatabaseAndTable(*select_query, 0)) @@ -655,7 +650,7 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze( if (settings.enable_optimize_predicate_expression) replaceJoinedTable(node); - const auto & joined_expression = static_cast(*node->table_expression); + const auto & joined_expression = node->table_expression->as(); DatabaseAndTableWithAlias table(joined_expression, context.getCurrentDatabase()); NamesAndTypesList joined_columns = getNamesAndTypeListFromTableExpression(joined_expression, context); diff --git a/dbms/src/Interpreters/TranslateQualifiedNamesVisitor.cpp b/dbms/src/Interpreters/TranslateQualifiedNamesVisitor.cpp index 3d8a67ae766..47ab1528c65 100644 --- a/dbms/src/Interpreters/TranslateQualifiedNamesVisitor.cpp +++ b/dbms/src/Interpreters/TranslateQualifiedNamesVisitor.cpp @@ -31,13 +31,11 @@ namespace ErrorCodes bool TranslateQualifiedNamesMatcher::needChildVisit(ASTPtr & node, const ASTPtr & child) { /// Do not go to FROM, JOIN, subqueries. - if (typeid_cast(child.get()) || - typeid_cast(child.get())) + if (child->as() || child->as()) return false; /// Processed nodes. Do not go into children. - if (typeid_cast(node.get()) || - typeid_cast(node.get())) + if (node->as() || node->as()) return false; /// ASTSelectQuery + others @@ -46,15 +44,15 @@ bool TranslateQualifiedNamesMatcher::needChildVisit(ASTPtr & node, const ASTPtr void TranslateQualifiedNamesMatcher::visit(ASTPtr & ast, Data & data) { - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) visit(*t, ast, data); - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) visit(*t, ast, data); - if (auto * t = typeid_cast(ast.get())) + if (auto * t = ast->as()) visit(*t, ast, data); - if (auto * node = typeid_cast(ast.get())) + if (auto * node = ast->as()) visit(*node, ast, data); - if (auto * node = typeid_cast(ast.get())) + if (auto * node = ast->as()) visit(*node, ast, data); } @@ -91,7 +89,7 @@ void TranslateQualifiedNamesMatcher::visit(ASTFunction & node, const ASTPtr &, D String func_name_lowercase = Poco::toLower(node.name); if (func_name_lowercase == "count" && func_arguments->children.size() == 1 && - typeid_cast(func_arguments->children[0].get())) + func_arguments->children[0]->as()) func_arguments->children.clear(); } @@ -173,14 +171,14 @@ void TranslateQualifiedNamesMatcher::visit(ASTExpressionList & node, const ASTPt bool has_asterisk = false; for (const auto & child : node.children) { - if (typeid_cast(child.get())) + if (child->as()) { if (tables_with_columns.empty()) throw Exception("An asterisk cannot be replaced with empty columns.", ErrorCodes::LOGICAL_ERROR); has_asterisk = true; break; } - else if (auto qa = typeid_cast(child.get())) + else if (const auto * qa = child->as()) { visit(*qa, child, data); /// check if it's OK before rewrite has_asterisk = true; @@ -197,7 +195,7 @@ void TranslateQualifiedNamesMatcher::visit(ASTExpressionList & node, const ASTPt for (const auto & child : old_children) { - if (const auto * asterisk = typeid_cast(child.get())) + if (const auto * asterisk = child->as()) { bool first_table = true; for (const auto & [table, table_columns] : tables_with_columns) @@ -214,7 +212,7 @@ void TranslateQualifiedNamesMatcher::visit(ASTExpressionList & node, const ASTPt first_table = false; } } - else if (const auto * qualified_asterisk = typeid_cast(child.get())) + else if (const auto * qualified_asterisk = child->as()) { DatabaseAndTableWithAlias ident_db_and_name(qualified_asterisk->children[0]); @@ -239,15 +237,15 @@ void TranslateQualifiedNamesMatcher::visit(ASTExpressionList & node, const ASTPt /// 'select * from a join b using id' should result one 'id' column void TranslateQualifiedNamesMatcher::extractJoinUsingColumns(const ASTPtr ast, Data & data) { - const auto & table_join = typeid_cast(*ast); + const auto & table_join = ast->as(); if (table_join.using_expression_list) { - auto & keys = typeid_cast(*table_join.using_expression_list); + const auto & keys = table_join.using_expression_list->as(); for (const auto & key : keys.children) if (auto opt_column = getIdentifierName(key)) data.join_using_columns.insert(*opt_column); - else if (typeid_cast(key.get())) + else if (key->as()) data.join_using_columns.insert(key->getColumnName()); else { diff --git a/dbms/src/Interpreters/evaluateConstantExpression.cpp b/dbms/src/Interpreters/evaluateConstantExpression.cpp index ccf29968a5c..7fe92b6d907 100644 --- a/dbms/src/Interpreters/evaluateConstantExpression.cpp +++ b/dbms/src/Interpreters/evaluateConstantExpression.cpp @@ -60,11 +60,11 @@ std::pair> evaluateConstantExpression(co ASTPtr evaluateConstantExpressionAsLiteral(const ASTPtr & node, const Context & context) { /// Branch with string in query. - if (typeid_cast(node.get())) + if (node->as()) return node; /// Branch with TableFunction in query. - if (auto table_func_ptr = typeid_cast(node.get())) + if (const auto * table_func_ptr = node->as()) if (TableFunctionFactory::instance().isTableFunctionName(table_func_ptr->name)) return node; @@ -73,7 +73,7 @@ ASTPtr evaluateConstantExpressionAsLiteral(const ASTPtr & node, const Context & ASTPtr evaluateConstantExpressionOrIdentifierAsLiteral(const ASTPtr & node, const Context & context) { - if (auto id = typeid_cast(node.get())) + if (const auto * id = node->as()) return std::make_shared(id->name); return evaluateConstantExpressionAsLiteral(node, context); @@ -145,10 +145,8 @@ namespace { const auto * left = fn->arguments->children.front().get(); const auto * right = fn->arguments->children.back().get(); - const auto * identifier = typeid_cast(left) ? typeid_cast(left) - : typeid_cast(right); - const auto * literal = typeid_cast(left) ? typeid_cast(left) - : typeid_cast(right); + const auto * identifier = left->as() ? left->as() : right->as(); + const auto * literal = left->as() ? left->as() : right->as(); return analyzeEquals(identifier, literal, expr); } @@ -156,15 +154,15 @@ namespace { const auto * left = fn->arguments->children.front().get(); const auto * right = fn->arguments->children.back().get(); - const auto * identifier = typeid_cast(left); - const auto * inner_fn = typeid_cast(right); + const auto * identifier = left->as(); + const auto * inner_fn = right->as(); if (!inner_fn) { return {}; } - const auto * tuple = typeid_cast(inner_fn->children.front().get()); + const auto * tuple = inner_fn->children.front()->as(); if (!tuple) { @@ -175,7 +173,7 @@ namespace for (const auto & child : tuple->children) { - const auto * literal = typeid_cast(child.get()); + const auto * literal = child->as(); const auto dnf = analyzeEquals(identifier, literal, expr); if (dnf.empty()) @@ -190,7 +188,7 @@ namespace } else if (fn->name == "or") { - const auto * args = typeid_cast(fn->children.front().get()); + const auto * args = fn->children.front()->as(); if (!args) { @@ -201,7 +199,7 @@ namespace for (const auto & arg : args->children) { - const auto dnf = analyzeFunction(typeid_cast(arg.get()), expr); + const auto dnf = analyzeFunction(arg->as(), expr); if (dnf.empty()) { @@ -215,7 +213,7 @@ namespace } else if (fn->name == "and") { - const auto * args = typeid_cast(fn->children.front().get()); + const auto * args = fn->children.front()->as(); if (!args) { @@ -226,7 +224,7 @@ namespace for (const auto & arg : args->children) { - const auto dnf = analyzeFunction(typeid_cast(arg.get()), expr); + const auto dnf = analyzeFunction(arg->as(), expr); if (dnf.empty()) { @@ -249,7 +247,7 @@ std::optional evaluateExpressionOverConstantCondition(const ASTPtr & nod // TODO: `node` may be always-false literal. - if (const auto fn = typeid_cast(node.get())) + if (const auto * fn = node->as()) { const auto dnf = analyzeFunction(fn, target_expr); diff --git a/dbms/src/Interpreters/executeQuery.cpp b/dbms/src/Interpreters/executeQuery.cpp index 069c5c67abc..56f515e91c8 100644 --- a/dbms/src/Interpreters/executeQuery.cpp +++ b/dbms/src/Interpreters/executeQuery.cpp @@ -169,7 +169,7 @@ static std::tuple executeQueryImpl( /// TODO Parser should fail early when max_query_size limit is reached. ast = parseQuery(parser, begin, end, "", max_query_size); - auto * insert_query = dynamic_cast(ast.get()); + auto * insert_query = ast->as(); if (insert_query && insert_query->data) { query_end = insert_query->data; @@ -208,7 +208,7 @@ static std::tuple executeQueryImpl( /// Put query to process list. But don't put SHOW PROCESSLIST query itself. ProcessList::EntryPtr process_list_entry; - if (!internal && nullptr == typeid_cast(&*ast)) + if (!internal && !ast->as()) { process_list_entry = context.getProcessList().insert(query, ast.get(), context); context.setProcessListElement(&process_list_entry->get()); @@ -488,7 +488,8 @@ void executeQuery( if (streams.in) { - const ASTQueryWithOutput * ast_query_with_output = dynamic_cast(ast.get()); + /// FIXME: try to prettify this cast using `as<>()` + const auto * ast_query_with_output = dynamic_cast(ast.get()); WriteBuffer * out_buf = &ostr; std::optional out_file_buf; @@ -497,7 +498,7 @@ void executeQuery( if (!allow_into_outfile) throw Exception("INTO OUTFILE is not allowed", ErrorCodes::INTO_OUTFILE_NOT_ALLOWED); - const auto & out_file = typeid_cast(*ast_query_with_output->out_file).value.safeGet(); + const auto & out_file = ast_query_with_output->out_file->as().value.safeGet(); out_file_buf.emplace(out_file, DBMS_DEFAULT_BUFFER_SIZE, O_WRONLY | O_EXCL | O_CREAT); out_buf = &*out_file_buf; } diff --git a/dbms/src/Interpreters/getClusterName.cpp b/dbms/src/Interpreters/getClusterName.cpp index bc32e3dbea7..d162cbdab9d 100644 --- a/dbms/src/Interpreters/getClusterName.cpp +++ b/dbms/src/Interpreters/getClusterName.cpp @@ -18,14 +18,14 @@ namespace ErrorCodes std::string getClusterName(const IAST & node) { - if (const ASTIdentifier * ast_id = typeid_cast(&node)) + if (const auto * ast_id = node.as()) return ast_id->name; - if (const ASTLiteral * ast_lit = typeid_cast(&node)) + if (const auto * ast_lit = node.as()) return ast_lit->value.safeGet(); /// A hack to support hyphens in cluster names. - if (const ASTFunction * ast_func = typeid_cast(&node)) + if (const auto * ast_func = node.as()) { if (ast_func->name != "minus" || !ast_func->arguments || ast_func->arguments->children.size() < 2) throw Exception("Illegal expression instead of cluster name.", ErrorCodes::BAD_ARGUMENTS); diff --git a/dbms/src/Interpreters/interpretSubquery.cpp b/dbms/src/Interpreters/interpretSubquery.cpp index a585f7edc42..d46217695f9 100644 --- a/dbms/src/Interpreters/interpretSubquery.cpp +++ b/dbms/src/Interpreters/interpretSubquery.cpp @@ -19,9 +19,9 @@ std::shared_ptr interpretSubquery( const ASTPtr & table_expression, const Context & context, size_t subquery_depth, const Names & required_source_columns) { /// Subquery or table name. The name of the table is similar to the subquery `SELECT * FROM t`. - const ASTSubquery * subquery = typeid_cast(table_expression.get()); - const ASTFunction * function = typeid_cast(table_expression.get()); - const ASTIdentifier * table = typeid_cast(table_expression.get()); + const auto * subquery = table_expression->as(); + const auto * function = table_expression->as(); + const auto * table = table_expression->as(); if (!subquery && !table && !function) throw Exception("Table expression is undefined, Method: ExpressionAnalyzer::interpretSubquery." , ErrorCodes::LOGICAL_ERROR); @@ -65,7 +65,7 @@ std::shared_ptr interpretSubquery( auto query_context = const_cast(&context.getQueryContext()); const auto & storage = query_context->executeTableFunction(table_expression); columns = storage->getColumns().ordinary; - select_query->addTableFunction(*const_cast(&table_expression)); + select_query->addTableFunction(*const_cast(&table_expression)); // XXX: const_cast should be avoided! } else { @@ -94,9 +94,9 @@ std::shared_ptr interpretSubquery( std::set all_column_names; std::set assigned_column_names; - if (ASTSelectWithUnionQuery * select_with_union = typeid_cast(query.get())) + if (const auto * select_with_union = query->as()) { - if (ASTSelectQuery * select = typeid_cast(select_with_union->list_of_selects->children.at(0).get())) + if (const auto * select = select_with_union->list_of_selects->children.at(0)->as()) { for (auto & expr : select->select_expression_list->children) all_column_names.insert(expr->getAliasOrColumnName()); diff --git a/dbms/src/Interpreters/loadMetadata.cpp b/dbms/src/Interpreters/loadMetadata.cpp index 2a94a6d2ff1..e0caa8f433d 100644 --- a/dbms/src/Interpreters/loadMetadata.cpp +++ b/dbms/src/Interpreters/loadMetadata.cpp @@ -39,7 +39,7 @@ static void executeCreateQuery( ParserCreateQuery parser; ASTPtr ast = parseQuery(parser, query.data(), query.data() + query.size(), "in file " + file_name, 0); - ASTCreateQuery & ast_create_query = typeid_cast(*ast); + auto & ast_create_query = ast->as(); ast_create_query.attach = true; ast_create_query.database = database; diff --git a/dbms/src/Parsers/ASTAlterQuery.h b/dbms/src/Parsers/ASTAlterQuery.h index 2c77e2031de..7261170288a 100644 --- a/dbms/src/Parsers/ASTAlterQuery.h +++ b/dbms/src/Parsers/ASTAlterQuery.h @@ -123,7 +123,7 @@ public: void add(const ASTPtr & command) { - commands.push_back(static_cast(command.get())); + commands.push_back(command->as()); children.push_back(command); } diff --git a/dbms/src/Parsers/ASTFunction.cpp b/dbms/src/Parsers/ASTFunction.cpp index 2bf521571aa..b45bded9664 100644 --- a/dbms/src/Parsers/ASTFunction.cpp +++ b/dbms/src/Parsers/ASTFunction.cpp @@ -62,7 +62,7 @@ ASTPtr ASTFunction::clone() const */ static bool highlightStringLiteralWithMetacharacters(const ASTPtr & node, const IAST::FormatSettings & settings, const char * metacharacters) { - if (auto literal = dynamic_cast(node.get())) + if (const auto * literal = node->as()) { if (literal->value.getType() == Field::Types::String) { @@ -132,7 +132,7 @@ void ASTFunction::formatImplWithoutAlias(const FormatSettings & settings, Format * Instead, add a space. * PS. You can not just ask to add parentheses - see formatImpl for ASTLiteral. */ - if (name == "negate" && typeid_cast(&*arguments->children[0])) + if (name == "negate" && arguments->children[0]->as()) settings.ostr << ' '; arguments->formatImpl(settings, state, nested_need_parens); @@ -203,7 +203,7 @@ void ASTFunction::formatImplWithoutAlias(const FormatSettings & settings, Format if (!written && 0 == strcmp(name.c_str(), "tupleElement")) { /// It can be printed in a form of 'x.1' only if right hand side is unsigned integer literal. - if (const ASTLiteral * lit = typeid_cast(arguments->children[1].get())) + if (const auto * lit = arguments->children[1]->as()) { if (lit->value.getType() == Field::Types::UInt64) { @@ -222,7 +222,7 @@ void ASTFunction::formatImplWithoutAlias(const FormatSettings & settings, Format if (frame.need_parens) settings.ostr << '('; - const ASTFunction * first_arg_func = typeid_cast(arguments->children[0].get()); + const auto * first_arg_func = arguments->children[0]->as(); if (first_arg_func && first_arg_func->name == "tuple" && first_arg_func->arguments diff --git a/dbms/src/Parsers/ASTIdentifier.cpp b/dbms/src/Parsers/ASTIdentifier.cpp index 1a9db37391a..6b0329409a3 100644 --- a/dbms/src/Parsers/ASTIdentifier.cpp +++ b/dbms/src/Parsers/ASTIdentifier.cpp @@ -83,17 +83,10 @@ ASTPtr createTableIdentifier(const String & database_name, const String & table_ return database_and_table; } -bool isIdentifier(const IAST * const ast) -{ - if (ast) - return typeid_cast(ast); - return false; -} - std::optional getIdentifierName(const IAST * const ast) { if (ast) - if (auto node = typeid_cast(ast)) + if (const auto * node = ast->as()) return node->name; return {}; } @@ -101,7 +94,7 @@ std::optional getIdentifierName(const IAST * const ast) bool getIdentifierName(const ASTPtr & ast, String & name) { if (ast) - if (auto node = typeid_cast(ast.get())) + if (const auto * node = ast->as()) { name = node->name; return true; @@ -112,7 +105,7 @@ bool getIdentifierName(const ASTPtr & ast, String & name) void setIdentifierSpecial(ASTPtr & ast) { if (ast) - if (ASTIdentifier * id = typeid_cast(ast.get())) + if (auto * id = ast->as()) id->semantic->special = true; } diff --git a/dbms/src/Parsers/ASTIdentifier.h b/dbms/src/Parsers/ASTIdentifier.h index 1439ab2dcbd..434f84eb77e 100644 --- a/dbms/src/Parsers/ASTIdentifier.h +++ b/dbms/src/Parsers/ASTIdentifier.h @@ -69,9 +69,6 @@ private: ASTPtr createTableIdentifier(const String & database_name, const String & table_name); void setIdentifierSpecial(ASTPtr & ast); -bool isIdentifier(const IAST * const ast); -inline bool isIdentifier(const ASTPtr & ast) { return isIdentifier(ast.get()); } - std::optional getIdentifierName(const IAST * const ast); inline std::optional getIdentifierName(const ASTPtr & ast) { return getIdentifierName(ast.get()); } bool getIdentifierName(const ASTPtr & ast, String & name); diff --git a/dbms/src/Parsers/ASTQueryWithOutput.cpp b/dbms/src/Parsers/ASTQueryWithOutput.cpp index c28b15cb8c1..95bcaaad416 100644 --- a/dbms/src/Parsers/ASTQueryWithOutput.cpp +++ b/dbms/src/Parsers/ASTQueryWithOutput.cpp @@ -49,7 +49,8 @@ void ASTQueryWithOutput::formatImpl(const FormatSettings & s, FormatState & stat bool ASTQueryWithOutput::resetOutputASTIfExist(IAST & ast) { - if (auto ast_with_output = dynamic_cast(&ast)) + /// FIXME: try to prettify this cast using `as<>()` + if (auto * ast_with_output = dynamic_cast(&ast)) { ast_with_output->format.reset(); ast_with_output->out_file.reset(); diff --git a/dbms/src/Parsers/ASTRenameQuery.h b/dbms/src/Parsers/ASTRenameQuery.h index 006c8583836..1666873ed9c 100644 --- a/dbms/src/Parsers/ASTRenameQuery.h +++ b/dbms/src/Parsers/ASTRenameQuery.h @@ -41,7 +41,7 @@ public: ASTPtr getRewrittenASTWithoutOnCluster(const std::string & new_database) const override { auto query_ptr = clone(); - auto & query = static_cast(*query_ptr); + auto & query = query_ptr->as(); query.cluster.clear(); for (Element & elem : query.elements) diff --git a/dbms/src/Parsers/ASTSelectQuery.cpp b/dbms/src/Parsers/ASTSelectQuery.cpp index 3534518c1b9..b7ad77ba48d 100644 --- a/dbms/src/Parsers/ASTSelectQuery.cpp +++ b/dbms/src/Parsers/ASTSelectQuery.cpp @@ -65,7 +65,7 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F s.ostr << (s.hilite ? hilite_keyword : "") << indent_str << "WITH " << (s.hilite ? hilite_none : ""); s.one_line ? with_expression_list->formatImpl(s, state, frame) - : typeid_cast(*with_expression_list).formatImplMultiline(s, state, frame); + : with_expression_list->as().formatImplMultiline(s, state, frame); s.ostr << s.nl_or_ws; } @@ -73,7 +73,7 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F s.one_line ? select_expression_list->formatImpl(s, state, frame) - : typeid_cast(*select_expression_list).formatImplMultiline(s, state, frame); + : select_expression_list->as().formatImplMultiline(s, state, frame); if (tables) { @@ -98,7 +98,7 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << "GROUP BY " << (s.hilite ? hilite_none : ""); s.one_line ? group_expression_list->formatImpl(s, state, frame) - : typeid_cast(*group_expression_list).formatImplMultiline(s, state, frame); + : group_expression_list->as().formatImplMultiline(s, state, frame); } if (group_by_with_rollup) @@ -121,7 +121,7 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F s.ostr << (s.hilite ? hilite_keyword : "") << s.nl_or_ws << indent_str << "ORDER BY " << (s.hilite ? hilite_none : ""); s.one_line ? order_expression_list->formatImpl(s, state, frame) - : typeid_cast(*order_expression_list).formatImplMultiline(s, state, frame); + : order_expression_list->as().formatImplMultiline(s, state, frame); } if (limit_by_value) @@ -131,7 +131,7 @@ void ASTSelectQuery::formatImpl(const FormatSettings & s, FormatState & state, F s.ostr << (s.hilite ? hilite_keyword : "") << " BY " << (s.hilite ? hilite_none : ""); s.one_line ? limit_by_expression_list->formatImpl(s, state, frame) - : typeid_cast(*limit_by_expression_list).formatImplMultiline(s, state, frame); + : limit_by_expression_list->as().formatImplMultiline(s, state, frame); } if (limit_length) @@ -161,15 +161,15 @@ static const ASTTableExpression * getFirstTableExpression(const ASTSelectQuery & if (!select.tables) return {}; - const ASTTablesInSelectQuery & tables_in_select_query = static_cast(*select.tables); + const auto & tables_in_select_query = select.tables->as(); if (tables_in_select_query.children.empty()) return {}; - const ASTTablesInSelectQueryElement & tables_element = static_cast(*tables_in_select_query.children[0]); + const auto & tables_element = tables_in_select_query.children[0]->as(); if (!tables_element.table_expression) return {}; - return static_cast(tables_element.table_expression.get()); + return tables_element.table_expression->as(); } static ASTTableExpression * getFirstTableExpression(ASTSelectQuery & select) @@ -177,15 +177,15 @@ static ASTTableExpression * getFirstTableExpression(ASTSelectQuery & select) if (!select.tables) return {}; - ASTTablesInSelectQuery & tables_in_select_query = static_cast(*select.tables); + auto & tables_in_select_query = select.tables->as(); if (tables_in_select_query.children.empty()) return {}; - ASTTablesInSelectQueryElement & tables_element = static_cast(*tables_in_select_query.children[0]); + auto & tables_element = tables_in_select_query.children[0]->as(); if (!tables_element.table_expression) return {}; - return static_cast(tables_element.table_expression.get()); + return tables_element.table_expression->as(); } static const ASTArrayJoin * getFirstArrayJoin(const ASTSelectQuery & select) @@ -193,18 +193,18 @@ static const ASTArrayJoin * getFirstArrayJoin(const ASTSelectQuery & select) if (!select.tables) return {}; - const ASTTablesInSelectQuery & tables_in_select_query = static_cast(*select.tables); + const auto & tables_in_select_query = select.tables->as(); if (tables_in_select_query.children.empty()) return {}; const ASTArrayJoin * array_join = nullptr; for (const auto & child : tables_in_select_query.children) { - const ASTTablesInSelectQueryElement & tables_element = static_cast(*child); + const auto & tables_element = child->as(); if (tables_element.array_join) { if (!array_join) - array_join = static_cast(tables_element.array_join.get()); + array_join = tables_element.array_join->as(); else throw Exception("Support for more than one ARRAY JOIN in query is not implemented", ErrorCodes::NOT_IMPLEMENTED); } @@ -218,14 +218,14 @@ static const ASTTablesInSelectQueryElement * getFirstTableJoin(const ASTSelectQu if (!select.tables) return {}; - const ASTTablesInSelectQuery & tables_in_select_query = static_cast(*select.tables); + const auto & tables_in_select_query = select.tables->as(); if (tables_in_select_query.children.empty()) return {}; const ASTTablesInSelectQueryElement * joined_table = nullptr; for (const auto & child : tables_in_select_query.children) { - const ASTTablesInSelectQueryElement & tables_element = static_cast(*child); + const auto & tables_element = child->as(); if (tables_element.table_join) { if (!joined_table) @@ -357,4 +357,3 @@ void ASTSelectQuery::addTableFunction(ASTPtr & table_function_ptr) } } - diff --git a/dbms/src/Parsers/ASTTablesInSelectQuery.cpp b/dbms/src/Parsers/ASTTablesInSelectQuery.cpp index 1f2adfb17cd..fb278046377 100644 --- a/dbms/src/Parsers/ASTTablesInSelectQuery.cpp +++ b/dbms/src/Parsers/ASTTablesInSelectQuery.cpp @@ -208,7 +208,7 @@ void ASTArrayJoin::formatImpl(const FormatSettings & settings, FormatState & sta settings.one_line ? expression_list->formatImpl(settings, state, frame) - : typeid_cast(*expression_list).formatImplMultiline(settings, state, frame); + : expression_list->as().formatImplMultiline(settings, state, frame); } @@ -218,7 +218,7 @@ void ASTTablesInSelectQueryElement::formatImpl(const FormatSettings & settings, { if (table_join) { - static_cast(*table_join).formatImplBeforeTable(settings, state, frame); + table_join->as().formatImplBeforeTable(settings, state, frame); settings.ostr << " "; } @@ -226,7 +226,7 @@ void ASTTablesInSelectQueryElement::formatImpl(const FormatSettings & settings, settings.ostr << " "; if (table_join) - static_cast(*table_join).formatImplAfterTable(settings, state, frame); + table_join->as().formatImplAfterTable(settings, state, frame); } else if (array_join) { diff --git a/dbms/src/Parsers/ExpressionElementParsers.cpp b/dbms/src/Parsers/ExpressionElementParsers.cpp index e38689467ea..4cde7219fdb 100644 --- a/dbms/src/Parsers/ExpressionElementParsers.cpp +++ b/dbms/src/Parsers/ExpressionElementParsers.cpp @@ -82,7 +82,7 @@ bool ParserParenthesisExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & return false; ++pos; - ASTExpressionList & expr_list = typeid_cast(*contents_node); + const auto & expr_list = contents_node->as(); /// empty expression in parentheses is not allowed if (expr_list.children.empty()) @@ -125,7 +125,7 @@ bool ParserSubquery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) ++pos; node = std::make_shared(); - typeid_cast(*node).children.push_back(select_node); + node->children.push_back(select_node); return true; } @@ -170,7 +170,7 @@ bool ParserCompoundIdentifier::parseImpl(Pos & pos, ASTPtr & node, Expected & ex String name; std::vector parts; - const ASTExpressionList & list = static_cast(*id_list.get()); + const auto & list = id_list->as(); for (const auto & child : list.children) { if (!name.empty()) @@ -1075,7 +1075,7 @@ bool ParserArrayOfLiterals::parseImpl(Pos & pos, ASTPtr & node, Expected & expec if (!literal_p.parse(pos, literal_node, expected)) return false; - arr.push_back(typeid_cast(*literal_node).value); + arr.push_back(literal_node->as().value); } expected.add(pos, "closing square bracket"); @@ -1254,7 +1254,8 @@ bool ParserWithOptionalAlias::parseImpl(Pos & pos, ASTPtr & node, Expected & exp ASTPtr alias_node; if (ParserAlias(allow_alias_without_as_keyword_now).parse(pos, alias_node, expected)) { - if (ASTWithAlias * ast_with_alias = dynamic_cast(node.get())) + /// FIXME: try to prettify this cast using `as<>()` + if (auto * ast_with_alias = dynamic_cast(node.get())) { getIdentifierName(alias_node, ast_with_alias->alias); ast_with_alias->prefer_alias_to_column_name = prefer_alias_to_column_name; @@ -1325,4 +1326,3 @@ bool ParserOrderByElement::parseImpl(Pos & pos, ASTPtr & node, Expected & expect } } - diff --git a/dbms/src/Parsers/ExpressionListParsers.cpp b/dbms/src/Parsers/ExpressionListParsers.cpp index 00f72d6a369..b948a22ce2a 100644 --- a/dbms/src/Parsers/ExpressionListParsers.cpp +++ b/dbms/src/Parsers/ExpressionListParsers.cpp @@ -213,7 +213,7 @@ bool ParserVariableArityOperatorList::parseImpl(Pos & pos, ASTPtr & node, Expect if (!arguments) { node = makeASTFunction(function_name, node); - arguments = static_cast(*node).arguments; + arguments = node->as().arguments; } ASTPtr elem; @@ -540,8 +540,7 @@ bool ParserExpressionList::parseImpl(Pos & pos, ASTPtr & node, Expected & expect bool ParserNotEmptyExpressionList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { - return nested_parser.parse(pos, node, expected) - && !typeid_cast(*node).children.empty(); + return nested_parser.parse(pos, node, expected) && !node->children.empty(); } diff --git a/dbms/src/Parsers/IAST.h b/dbms/src/Parsers/IAST.h index 7692691073d..65572d922d2 100644 --- a/dbms/src/Parsers/IAST.h +++ b/dbms/src/Parsers/IAST.h @@ -1,13 +1,14 @@ #pragma once -#include -#include -#include -#include - #include -#include +#include #include +#include +#include + +#include +#include +#include class SipHash; @@ -26,16 +27,12 @@ namespace ErrorCodes using IdentifierNameSet = std::set; -class IAST; -using ASTPtr = std::shared_ptr; -using ASTs = std::vector; - class WriteBuffer; /** Element of the syntax tree (hereinafter - directed acyclic graph with elements of semantics) */ -class IAST : public std::enable_shared_from_this +class IAST : public std::enable_shared_from_this, public TypePromotion { public: ASTs children; diff --git a/dbms/src/Parsers/IAST_fwd.h b/dbms/src/Parsers/IAST_fwd.h new file mode 100644 index 00000000000..30408a3792f --- /dev/null +++ b/dbms/src/Parsers/IAST_fwd.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include + +namespace DB +{ + +class IAST; +using ASTPtr = std::shared_ptr; +using ASTs = std::vector; + +} // namespace DB diff --git a/dbms/src/Parsers/ParserAlterQuery.cpp b/dbms/src/Parsers/ParserAlterQuery.cpp index 818362e9c95..b33679ad26b 100644 --- a/dbms/src/Parsers/ParserAlterQuery.cpp +++ b/dbms/src/Parsers/ParserAlterQuery.cpp @@ -202,7 +202,7 @@ bool ParserAlterCommand::parseImpl(Pos & pos, ASTPtr & node, Expected & expected if (!parser_string_literal.parse(pos, ast_from, expected)) return false; - command->from = typeid_cast(*ast_from).value.get(); + command->from = ast_from->as().value.get(); command->type = ASTAlterCommand::FETCH_PARTITION; } else if (s_freeze.ignore(pos, expected)) @@ -229,7 +229,7 @@ bool ParserAlterCommand::parseImpl(Pos & pos, ASTPtr & node, Expected & expected if (!parser_string_literal.parse(pos, ast_with_name, expected)) return false; - command->with_name = typeid_cast(*ast_with_name).value.get(); + command->with_name = ast_with_name->as().value.get(); } } else if (s_modify_column.ignore(pos, expected)) diff --git a/dbms/src/Parsers/ParserCreateQuery.cpp b/dbms/src/Parsers/ParserCreateQuery.cpp index 5a8ad919b58..72610951868 100644 --- a/dbms/src/Parsers/ParserCreateQuery.cpp +++ b/dbms/src/Parsers/ParserCreateQuery.cpp @@ -126,8 +126,8 @@ bool ParserIndexDeclaration::parseImpl(Pos & pos, ASTPtr & node, Expected & expe return false; auto index = std::make_shared(); - index->name = typeid_cast(*name).name; - index->granularity = typeid_cast(*granularity).value.get(); + index->name = name->as().name; + index->granularity = granularity->as().value.get(); index->set(index->expr, expr); index->set(index->type, type); node = index; @@ -179,9 +179,9 @@ bool ParserColumnsOrIndicesDeclarationList::parseImpl(Pos & pos, ASTPtr & node, for (const auto & elem : list->children) { - if (typeid_cast(elem.get())) + if (elem->as()) columns->children.push_back(elem); - else if (typeid_cast(elem.get())) + else if (elem->as()) indices->children.push_back(elem); else return false; diff --git a/dbms/src/Parsers/ParserDropQuery.cpp b/dbms/src/Parsers/ParserDropQuery.cpp index c3a97a222d2..ca757ae6168 100644 --- a/dbms/src/Parsers/ParserDropQuery.cpp +++ b/dbms/src/Parsers/ParserDropQuery.cpp @@ -34,7 +34,7 @@ bool ParserDropQuery::parseDetachQuery(Pos & pos, ASTPtr & node, Expected & expe { if (parseDropQuery(pos, node, expected)) { - ASTDropQuery * drop_query = static_cast(node.get()); + auto * drop_query = node->as(); drop_query->kind = ASTDropQuery::Kind::Detach; return true; } @@ -45,7 +45,7 @@ bool ParserDropQuery::parseTruncateQuery(Pos & pos, ASTPtr & node, Expected & ex { if (parseDropQuery(pos, node, expected)) { - ASTDropQuery * drop_query = static_cast(node.get()); + auto * drop_query = node->as(); drop_query->kind = ASTDropQuery::Kind::Truncate; return true; } diff --git a/dbms/src/Parsers/ParserPartition.cpp b/dbms/src/Parsers/ParserPartition.cpp index 1daf4dead18..6d2c259f8bf 100644 --- a/dbms/src/Parsers/ParserPartition.cpp +++ b/dbms/src/Parsers/ParserPartition.cpp @@ -26,7 +26,7 @@ bool ParserPartition::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) if (!parser_string_literal.parse(pos, partition_id, expected)) return false; - partition->id = dynamic_cast(*partition_id).value.get(); + partition->id = partition_id->as().value.get(); } else { @@ -37,10 +37,10 @@ bool ParserPartition::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) size_t fields_count; StringRef fields_str; - const auto * tuple_ast = typeid_cast(value.get()); + const auto * tuple_ast = value->as(); if (tuple_ast && tuple_ast->name == "tuple") { - const auto * arguments_ast = dynamic_cast(tuple_ast->arguments.get()); + const auto * arguments_ast = tuple_ast->arguments->as(); if (arguments_ast) fields_count = arguments_ast->children.size(); else diff --git a/dbms/src/Parsers/ParserQueryWithOutput.cpp b/dbms/src/Parsers/ParserQueryWithOutput.cpp index d1679067854..c41e0946a96 100644 --- a/dbms/src/Parsers/ParserQueryWithOutput.cpp +++ b/dbms/src/Parsers/ParserQueryWithOutput.cpp @@ -62,6 +62,7 @@ bool ParserQueryWithOutput::parseImpl(Pos & pos, ASTPtr & node, Expected & expec if (!parsed) return false; + /// FIXME: try to prettify this cast using `as<>()` auto & query_with_output = dynamic_cast(*query); ParserKeyword s_into_outfile("INTO OUTFILE"); diff --git a/dbms/src/Parsers/ParserSelectWithUnionQuery.cpp b/dbms/src/Parsers/ParserSelectWithUnionQuery.cpp index a0935074771..cebe8ba876d 100644 --- a/dbms/src/Parsers/ParserSelectWithUnionQuery.cpp +++ b/dbms/src/Parsers/ParserSelectWithUnionQuery.cpp @@ -11,7 +11,7 @@ namespace DB static void getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects) { - if (ASTSelectWithUnionQuery * inner_union = typeid_cast(ast_select.get())) + if (auto * inner_union = ast_select->as()) { for (auto & child : inner_union->list_of_selects->children) getSelectsFromUnionListNode(child, selects); diff --git a/dbms/src/Parsers/ParserSetQuery.cpp b/dbms/src/Parsers/ParserSetQuery.cpp index 14b5b4bec5e..9eb94a76364 100644 --- a/dbms/src/Parsers/ParserSetQuery.cpp +++ b/dbms/src/Parsers/ParserSetQuery.cpp @@ -32,7 +32,7 @@ static bool parseNameValuePair(ASTSetQuery::Change & change, IParser::Pos & pos, return false; getIdentifierName(name, change.name); - change.value = typeid_cast(*value).value; + change.value = value->as().value; return true; } diff --git a/dbms/src/Parsers/ParserShowTablesQuery.cpp b/dbms/src/Parsers/ParserShowTablesQuery.cpp index dc854883cfe..9c247a284c1 100644 --- a/dbms/src/Parsers/ParserShowTablesQuery.cpp +++ b/dbms/src/Parsers/ParserShowTablesQuery.cpp @@ -67,7 +67,7 @@ bool ParserShowTablesQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec getIdentifierName(database, query->from); if (like) - query->like = safeGet(typeid_cast(*like).value); + query->like = safeGet(like->as().value); node = query; diff --git a/dbms/src/Parsers/ParserUnionQueryElement.cpp b/dbms/src/Parsers/ParserUnionQueryElement.cpp index e6b8ba66cb0..efd022e6362 100644 --- a/dbms/src/Parsers/ParserUnionQueryElement.cpp +++ b/dbms/src/Parsers/ParserUnionQueryElement.cpp @@ -13,7 +13,7 @@ bool ParserUnionQueryElement::parseImpl(Pos & pos, ASTPtr & node, Expected & exp if (!ParserSubquery().parse(pos, node, expected) && !ParserSelectQuery().parse(pos, node, expected)) return false; - if (auto * ast_subquery = typeid_cast(node.get())) + if (const auto * ast_subquery = node->as()) node = ast_subquery->children.at(0); return true; diff --git a/dbms/src/Parsers/parseIdentifierOrStringLiteral.cpp b/dbms/src/Parsers/parseIdentifierOrStringLiteral.cpp index 2fa71415efb..815a5d3f3cc 100644 --- a/dbms/src/Parsers/parseIdentifierOrStringLiteral.cpp +++ b/dbms/src/Parsers/parseIdentifierOrStringLiteral.cpp @@ -17,7 +17,7 @@ bool parseIdentifierOrStringLiteral(IParser::Pos & pos, Expected & expected, Str if (!ParserStringLiteral().parse(pos, res, expected)) return false; - result = typeid_cast(*res).value.safeGet(); + result = res->as().value.safeGet(); } else result = *getIdentifierName(res); diff --git a/dbms/src/Parsers/parseQuery.cpp b/dbms/src/Parsers/parseQuery.cpp index 4f8ab83b7fd..3d761d09b13 100644 --- a/dbms/src/Parsers/parseQuery.cpp +++ b/dbms/src/Parsers/parseQuery.cpp @@ -236,7 +236,7 @@ ASTPtr tryParseQuery( /// If parsed query ends at data for insertion. Data for insertion could be in any format and not necessary be lexical correct. ASTInsertQuery * insert = nullptr; if (parse_res) - insert = typeid_cast(res.get()); + insert = res->as(); if (!(insert && insert->data)) { @@ -355,7 +355,7 @@ std::pair splitMultipartQuery(const std::string & queries, s ast = parseQueryAndMovePosition(parser, pos, end, "", true, 0); - ASTInsertQuery * insert = typeid_cast(ast.get()); + auto * insert = ast->as(); if (insert && insert->data) { diff --git a/dbms/src/Storages/AlterCommands.cpp b/dbms/src/Storages/AlterCommands.cpp index d3790aa3a19..27126b7bcdf 100644 --- a/dbms/src/Storages/AlterCommands.cpp +++ b/dbms/src/Storages/AlterCommands.cpp @@ -39,7 +39,7 @@ std::optional AlterCommand::parse(const ASTAlterCommand * command_ AlterCommand command; command.type = AlterCommand::ADD_COLUMN; - const auto & ast_col_decl = typeid_cast(*command_ast->col_decl); + const auto & ast_col_decl = command_ast->col_decl->as(); command.column_name = ast_col_decl.name; if (ast_col_decl.type) @@ -78,7 +78,7 @@ std::optional AlterCommand::parse(const ASTAlterCommand * command_ AlterCommand command; command.type = AlterCommand::MODIFY_COLUMN; - const auto & ast_col_decl = typeid_cast(*command_ast->col_decl); + const auto & ast_col_decl = command_ast->col_decl->as(); command.column_name = ast_col_decl.name; if (ast_col_decl.type) @@ -97,7 +97,7 @@ std::optional AlterCommand::parse(const ASTAlterCommand * command_ if (ast_col_decl.comment) { - const auto & ast_comment = typeid_cast(*ast_col_decl.comment); + const auto & ast_comment = ast_col_decl.comment->as(); command.comment = ast_comment.value.get(); } command.if_exists = command_ast->if_exists; @@ -109,7 +109,7 @@ std::optional AlterCommand::parse(const ASTAlterCommand * command_ AlterCommand command; command.type = COMMENT_COLUMN; command.column_name = *getIdentifierName(command_ast->column); - const auto & ast_comment = typeid_cast(*command_ast->comment); + const auto & ast_comment = command_ast->comment->as(); command.comment = ast_comment.value.get(); command.if_exists = command_ast->if_exists; return command; @@ -127,12 +127,12 @@ std::optional AlterCommand::parse(const ASTAlterCommand * command_ command.index_decl = command_ast->index_decl; command.type = AlterCommand::ADD_INDEX; - const auto & ast_index_decl = typeid_cast(*command_ast->index_decl); + const auto & ast_index_decl = command_ast->index_decl->as(); command.index_name = ast_index_decl.name; if (command_ast->index) - command.after_index_name = typeid_cast(*command_ast->index).name; + command.after_index_name = command_ast->index->as().name; command.if_not_exists = command_ast->if_not_exists; @@ -145,7 +145,7 @@ std::optional AlterCommand::parse(const ASTAlterCommand * command_ AlterCommand command; command.type = AlterCommand::DROP_INDEX; - command.index_name = typeid_cast(*(command_ast->index)).name; + command.index_name = command_ast->index->as().name; command.if_exists = command_ast->if_exists; return command; @@ -335,7 +335,7 @@ void AlterCommand::apply(ColumnsDescription & columns_description, IndicesDescri indices_description.indices.cend(), [this](const ASTPtr & index_ast) { - return typeid_cast(*index_ast).name == index_name; + return index_ast->as().name == index_name; })) { if (if_not_exists) @@ -354,7 +354,7 @@ void AlterCommand::apply(ColumnsDescription & columns_description, IndicesDescri indices_description.indices.end(), [this](const ASTPtr & index_ast) { - return typeid_cast(*index_ast).name == after_index_name; + return index_ast->as().name == after_index_name; }); if (insert_it == indices_description.indices.end()) @@ -373,7 +373,7 @@ void AlterCommand::apply(ColumnsDescription & columns_description, IndicesDescri indices_description.indices.end(), [this](const ASTPtr & index_ast) { - return typeid_cast(*index_ast).name == index_name; + return index_ast->as().name == index_name; }); if (erase_it == indices_description.indices.end()) diff --git a/dbms/src/Storages/ColumnsDescription.cpp b/dbms/src/Storages/ColumnsDescription.cpp index 92069098ebf..0f0636bc47f 100644 --- a/dbms/src/Storages/ColumnsDescription.cpp +++ b/dbms/src/Storages/ColumnsDescription.cpp @@ -158,7 +158,7 @@ void parseColumn(ReadBufferFromString & buf, ColumnsDescription & result, const String column_line; readEscapedStringUntilEOL(column_line, buf); ASTPtr ast = parseQuery(column_parser, column_line, "column parser", 0); - if (const ASTColumnDeclaration * col_ast = typeid_cast(ast.get())) + if (const auto * col_ast = ast->as()) { String column_name = col_ast->name; auto type = data_type_factory.get(col_ast->type); @@ -185,7 +185,7 @@ void parseColumn(ReadBufferFromString & buf, ColumnsDescription & result, const result.ordinary.emplace_back(column_name, std::move(type)); if (col_ast->comment) - if (auto comment_str = typeid_cast(*col_ast->comment).value.get(); !comment_str.empty()) + if (auto comment_str = col_ast->comment->as().value.get(); !comment_str.empty()) result.comments.emplace(column_name, std::move(comment_str)); if (col_ast->codec) diff --git a/dbms/src/Storages/Kafka/StorageKafka.cpp b/dbms/src/Storages/Kafka/StorageKafka.cpp index e1c6c19d418..5785e65d4d9 100644 --- a/dbms/src/Storages/Kafka/StorageKafka.cpp +++ b/dbms/src/Storages/Kafka/StorageKafka.cpp @@ -409,7 +409,7 @@ void registerStorageKafka(StorageFactory & factory) String brokers; if (args_count >= 1) { - auto ast = typeid_cast(engine_args[0].get()); + const auto * ast = engine_args[0]->as(); if (ast && ast->value.getType() == Field::Types::String) { brokers = safeGet(ast->value); @@ -429,7 +429,7 @@ void registerStorageKafka(StorageFactory & factory) if (args_count >= 2) { engine_args[1] = evaluateConstantExpressionAsLiteral(engine_args[1], args.local_context); - topic_list = static_cast(*engine_args[1]).value.safeGet(); + topic_list = engine_args[1]->as().value.safeGet(); } else if (kafka_settings.kafka_topic_list.changed) { @@ -447,7 +447,7 @@ void registerStorageKafka(StorageFactory & factory) if (args_count >= 3) { engine_args[2] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[2], args.local_context); - group = static_cast(*engine_args[2]).value.safeGet(); + group = engine_args[2]->as().value.safeGet(); } else if (kafka_settings.kafka_group_name.changed) { @@ -460,7 +460,7 @@ void registerStorageKafka(StorageFactory & factory) { engine_args[3] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[3], args.local_context); - auto ast = typeid_cast(engine_args[3].get()); + const auto * ast = engine_args[3]->as(); if (ast && ast->value.getType() == Field::Types::String) { format = safeGet(ast->value); @@ -481,7 +481,7 @@ void registerStorageKafka(StorageFactory & factory) { engine_args[4] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[4], args.local_context); - auto ast = typeid_cast(engine_args[4].get()); + const auto * ast = engine_args[4]->as(); String arg; if (ast && ast->value.getType() == Field::Types::String) { @@ -515,7 +515,7 @@ void registerStorageKafka(StorageFactory & factory) { engine_args[5] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[5], args.local_context); - auto ast = typeid_cast(engine_args[5].get()); + const auto * ast = engine_args[5]->as(); if (ast && ast->value.getType() == Field::Types::String) { schema = safeGet(ast->value); @@ -534,7 +534,7 @@ void registerStorageKafka(StorageFactory & factory) UInt64 num_consumers = 1; if (args_count >= 7) { - auto ast = typeid_cast(engine_args[6].get()); + const auto * ast = engine_args[6]->as(); if (ast && ast->value.getType() == Field::Types::UInt64) { num_consumers = safeGet(ast->value); @@ -553,7 +553,7 @@ void registerStorageKafka(StorageFactory & factory) UInt64 max_block_size = 0; if (args_count >= 8) { - auto ast = typeid_cast(engine_args[7].get()); + const auto * ast = engine_args[7]->as(); if (ast && ast->value.getType() == Field::Types::UInt64) { max_block_size = static_cast(safeGet(ast->value)); @@ -572,7 +572,7 @@ void registerStorageKafka(StorageFactory & factory) size_t skip_broken = 0; if (args_count >= 9) { - auto ast = typeid_cast(engine_args[8].get()); + const auto * ast = engine_args[8]->as(); if (ast && ast->value.getType() == Field::Types::UInt64) { skip_broken = static_cast(safeGet(ast->value)); diff --git a/dbms/src/Storages/MergeTree/KeyCondition.cpp b/dbms/src/Storages/MergeTree/KeyCondition.cpp index b64920d6233..61ae85549c9 100644 --- a/dbms/src/Storages/MergeTree/KeyCondition.cpp +++ b/dbms/src/Storages/MergeTree/KeyCondition.cpp @@ -284,7 +284,7 @@ KeyCondition::KeyCondition( Block block_with_constants = getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context); /// Trasform WHERE section to Reverse Polish notation - const ASTSelectQuery & select = typeid_cast(*query_info.query); + const auto & select = query_info.query->as(); if (select.where_expression) { traverseAST(select.where_expression, context, block_with_constants); @@ -321,7 +321,7 @@ static bool getConstant(const ASTPtr & expr, Block & block_with_constants, Field { String column_name = expr->getColumnName(); - if (const ASTLiteral * lit = typeid_cast(expr.get())) + if (const auto * lit = expr->as()) { /// By default block_with_constants has only one column named "_dummy". /// If block contains only constants it's may not be preprocessed by @@ -370,11 +370,11 @@ void KeyCondition::traverseAST(const ASTPtr & node, const Context & context, Blo { RPNElement element; - if (ASTFunction * func = typeid_cast(&*node)) + if (auto * func = node->as()) { if (operatorFromAST(func, element)) { - auto & args = typeid_cast(*func->arguments).children; + auto & args = func->arguments->children; for (size_t i = 0, size = args.size(); i < size; ++i) { traverseAST(args[i], context, block_with_constants); @@ -486,7 +486,7 @@ bool KeyCondition::tryPrepareSetIndex( } }; - const ASTFunction * left_arg_tuple = typeid_cast(left_arg.get()); + const auto * left_arg_tuple = left_arg->as(); if (left_arg_tuple && left_arg_tuple->name == "tuple") { const auto & tuple_elements = left_arg_tuple->arguments->children; @@ -502,7 +502,7 @@ bool KeyCondition::tryPrepareSetIndex( const ASTPtr & right_arg = args[1]; PreparedSetKey set_key; - if (typeid_cast(right_arg.get()) || typeid_cast(right_arg.get())) + if (right_arg->as() || right_arg->as()) set_key = PreparedSetKey::forSubquery(*right_arg); else set_key = PreparedSetKey::forLiteral(*right_arg, data_types); @@ -574,7 +574,7 @@ bool KeyCondition::isKeyPossiblyWrappedByMonotonicFunctionsImpl( return true; } - if (const ASTFunction * func = typeid_cast(node.get())) + if (const auto * func = node->as()) { const auto & args = func->arguments->children; if (args.size() != 1) @@ -620,9 +620,9 @@ bool KeyCondition::atomFromAST(const ASTPtr & node, const Context & context, Blo */ Field const_value; DataTypePtr const_type; - if (const ASTFunction * func = typeid_cast(node.get())) + if (const auto * func = node->as()) { - const ASTs & args = typeid_cast(*func->arguments).children; + const ASTs & args = func->arguments->children; if (args.size() != 2) return false; @@ -737,7 +737,7 @@ bool KeyCondition::operatorFromAST(const ASTFunction * func, RPNElement & out) /** Also a special function `indexHint` - works as if instead of calling a function there are just parentheses * (or, the same thing - calling the function `and` from one argument). */ - const ASTs & args = typeid_cast(*func->arguments).children; + const ASTs & args = func->arguments->children; if (func->name == "not") { diff --git a/dbms/src/Storages/MergeTree/MergeTreeData.cpp b/dbms/src/Storages/MergeTree/MergeTreeData.cpp index fe0a73705b0..adc83bc338b 100644 --- a/dbms/src/Storages/MergeTree/MergeTreeData.cpp +++ b/dbms/src/Storages/MergeTree/MergeTreeData.cpp @@ -405,7 +405,7 @@ ASTPtr MergeTreeData::extractKeyExpressionList(const ASTPtr & node) if (!node) return std::make_shared(); - const ASTFunction * expr_func = typeid_cast(node.get()); + const auto * expr_func = node->as(); if (expr_func && expr_func->name == "tuple") { @@ -1174,10 +1174,10 @@ void MergeTreeData::createConvertExpression(const DataPartPtr & part, const Name /// Remove old indices std::set new_indices_set; for (const auto & index_decl : new_indices) - new_indices_set.emplace(dynamic_cast(*index_decl.get()).name); + new_indices_set.emplace(index_decl->as().name); for (const auto & index_decl : old_indices) { - const auto & index = dynamic_cast(*index_decl.get()); + const auto & index = index_decl->as(); if (!new_indices_set.count(index.name)) { out_rename_map["skp_idx_" + index.name + ".idx"] = ""; @@ -2219,9 +2219,8 @@ void MergeTreeData::freezePartition(const ASTPtr & partition_ast, const String & if (format_version < MERGE_TREE_DATA_MIN_FORMAT_VERSION_WITH_CUSTOM_PARTITIONING) { - const auto & partition = dynamic_cast(*partition_ast); /// Month-partitioning specific - partition value can represent a prefix of the partition to freeze. - if (const auto * partition_lit = dynamic_cast(partition.value.get())) + if (const auto * partition_lit = partition_ast->as().value->as()) prefix = partition_lit->value.getType() == Field::Types::UInt64 ? toString(partition_lit->value.get()) : partition_lit->value.safeGet(); @@ -2276,7 +2275,7 @@ size_t MergeTreeData::getPartitionSize(const std::string & partition_id) const String MergeTreeData::getPartitionIDFromQuery(const ASTPtr & ast, const Context & context) { - const auto & partition_ast = typeid_cast(*ast); + const auto & partition_ast = ast->as(); if (!partition_ast.value) return partition_ast.id; @@ -2284,7 +2283,7 @@ String MergeTreeData::getPartitionIDFromQuery(const ASTPtr & ast, const Context if (format_version < MERGE_TREE_DATA_MIN_FORMAT_VERSION_WITH_CUSTOM_PARTITIONING) { /// Month-partitioning specific - partition ID can be passed in the partition value. - const auto * partition_lit = typeid_cast(partition_ast.value.get()); + const auto * partition_lit = partition_ast.value->as(); if (partition_lit && partition_lit->value.getType() == Field::Types::String) { String partition_id = partition_lit->value.get(); @@ -2502,7 +2501,7 @@ bool MergeTreeData::isPrimaryOrMinMaxKeyColumnPossiblyWrappedInFunctions(const A if (column_name == name) return true; - if (const ASTFunction * func = typeid_cast(node.get())) + if (const auto * func = node->as()) if (func->arguments->children.size() == 1) return isPrimaryOrMinMaxKeyColumnPossiblyWrappedInFunctions(func->arguments->children.front()); @@ -2514,7 +2513,7 @@ bool MergeTreeData::mayBenefitFromIndexForIn(const ASTPtr & left_in_operand) con /// Make sure that the left side of the IN operator contain part of the key. /// If there is a tuple on the left side of the IN operator, at least one item of the tuple /// must be part of the key (probably wrapped by a chain of some acceptable functions). - const ASTFunction * left_in_operand_tuple = typeid_cast(left_in_operand.get()); + const auto * left_in_operand_tuple = left_in_operand->as(); if (left_in_operand_tuple && left_in_operand_tuple->name == "tuple") { for (const auto & item : left_in_operand_tuple->arguments->children) diff --git a/dbms/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp b/dbms/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp index 759980a4fab..3f8c3aa3006 100644 --- a/dbms/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp +++ b/dbms/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp @@ -129,7 +129,7 @@ static RelativeSize convertAbsoluteSampleSizeToRelative(const ASTPtr & node, siz if (approx_total_rows == 0) return 1; - const ASTSampleRatio & node_sample = typeid_cast(*node); + const auto & node_sample = node->as(); auto absolute_sample_size = node_sample.ratio.numerator / node_sample.ratio.denominator; return std::min(RelativeSize(1), RelativeSize(absolute_sample_size) / RelativeSize(approx_total_rows)); @@ -287,7 +287,7 @@ BlockInputStreams MergeTreeDataSelectExecutor::readFromParts( RelativeSize relative_sample_size = 0; RelativeSize relative_sample_offset = 0; - ASTSelectQuery & select = typeid_cast(*query_info.query); + const auto & select = query_info.query->as(); auto select_sample_size = select.sample_size(); auto select_sample_offset = select.sample_offset(); @@ -295,8 +295,8 @@ BlockInputStreams MergeTreeDataSelectExecutor::readFromParts( if (select_sample_size) { relative_sample_size.assign( - typeid_cast(*select_sample_size).ratio.numerator, - typeid_cast(*select_sample_size).ratio.denominator); + select_sample_size->as().ratio.numerator, + select_sample_size->as().ratio.denominator); if (relative_sample_size < 0) throw Exception("Negative sample size", ErrorCodes::ARGUMENT_OUT_OF_BOUND); @@ -304,8 +304,8 @@ BlockInputStreams MergeTreeDataSelectExecutor::readFromParts( relative_sample_offset = 0; if (select_sample_offset) relative_sample_offset.assign( - typeid_cast(*select_sample_offset).ratio.numerator, - typeid_cast(*select_sample_offset).ratio.denominator); + select_sample_offset->as().ratio.numerator, + select_sample_offset->as().ratio.denominator); if (relative_sample_offset < 0) throw Exception("Negative sample offset", ErrorCodes::ARGUMENT_OUT_OF_BOUND); diff --git a/dbms/src/Storages/MergeTree/MergeTreeSetSkippingIndex.cpp b/dbms/src/Storages/MergeTree/MergeTreeSetSkippingIndex.cpp index 742d3971930..9a014676a1f 100644 --- a/dbms/src/Storages/MergeTree/MergeTreeSetSkippingIndex.cpp +++ b/dbms/src/Storages/MergeTree/MergeTreeSetSkippingIndex.cpp @@ -225,7 +225,7 @@ SetIndexCondition::SetIndexCondition( key_columns.insert(name); } - const ASTSelectQuery & select = typeid_cast(*query.query); + const auto & select = query.query->as(); /// Replace logical functions with bit functions. /// Working with UInt8: last bit = can be true, previous = can be false. @@ -298,8 +298,7 @@ void SetIndexCondition::traverseAST(ASTPtr & node) const { if (operatorFromAST(node)) { - auto * func = typeid_cast(&*node); - auto & args = typeid_cast(*func->arguments).children; + auto & args = node->as()->arguments->children; for (auto & arg : args) traverseAST(arg); @@ -314,13 +313,13 @@ bool SetIndexCondition::atomFromAST(ASTPtr & node) const { /// Function, literal or column - if (typeid_cast(node.get())) + if (node->as()) return true; - if (const auto * identifier = typeid_cast(node.get())) + if (const auto * identifier = node->as()) return key_columns.count(identifier->getColumnName()) != 0; - if (auto * func = typeid_cast(node.get())) + if (auto * func = node->as()) { if (key_columns.count(func->getColumnName())) { @@ -329,7 +328,7 @@ bool SetIndexCondition::atomFromAST(ASTPtr & node) const return true; } - ASTs & args = typeid_cast(*func->arguments).children; + auto & args = func->arguments->children; for (auto & arg : args) if (!atomFromAST(arg)) @@ -344,11 +343,11 @@ bool SetIndexCondition::atomFromAST(ASTPtr & node) const bool SetIndexCondition::operatorFromAST(ASTPtr & node) const { /// Functions AND, OR, NOT. Replace with bit*. - auto * func = typeid_cast(&*node); + auto * func = node->as(); if (!func) return false; - ASTs & args = typeid_cast(*func->arguments).children; + auto & args = func->arguments->children; if (func->name == "not") { @@ -419,12 +418,12 @@ static bool checkAtomName(const String & name) bool SetIndexCondition::checkASTUseless(const ASTPtr &node, bool atomic) const { - if (const auto * func = typeid_cast(node.get())) + if (const auto * func = node->as()) { if (key_columns.count(func->getColumnName())) return false; - const ASTs & args = typeid_cast(*func->arguments).children; + const ASTs & args = func->arguments->children; if (func->name == "and" || func->name == "indexHint") return checkASTUseless(args[0], atomic) && checkASTUseless(args[1], atomic); @@ -438,9 +437,9 @@ bool SetIndexCondition::checkASTUseless(const ASTPtr &node, bool atomic) const return std::any_of(args.begin(), args.end(), [this, &atomic](const auto & arg) { return checkASTUseless(arg, atomic); }); } - else if (const auto * literal = typeid_cast(node.get())) + else if (const auto * literal = node->as()) return !atomic && literal->value.get(); - else if (const auto * identifier = typeid_cast(node.get())) + else if (const auto * identifier = node->as()) return key_columns.find(identifier->getColumnName()) == key_columns.end(); else return true; @@ -476,8 +475,7 @@ std::unique_ptr setIndexCreator( if (!node->type->arguments || node->type->arguments->children.size() != 1) throw Exception("Set index must have exactly one argument.", ErrorCodes::INCORRECT_QUERY); else if (node->type->arguments->children.size() == 1) - max_rows = typeid_cast( - *node->type->arguments->children[0]).value.get(); + max_rows = node->type->arguments->children[0]->as().value.get(); ASTPtr expr_list = MergeTreeData::extractKeyExpressionList(node->expr->clone()); diff --git a/dbms/src/Storages/MergeTree/MergeTreeWhereOptimizer.cpp b/dbms/src/Storages/MergeTree/MergeTreeWhereOptimizer.cpp index 89a499b710a..fc0b13ebbc0 100644 --- a/dbms/src/Storages/MergeTree/MergeTreeWhereOptimizer.cpp +++ b/dbms/src/Storages/MergeTree/MergeTreeWhereOptimizer.cpp @@ -44,9 +44,8 @@ MergeTreeWhereOptimizer::MergeTreeWhereOptimizer( first_primary_key_column = data.primary_key_columns[0]; calculateColumnSizes(data, queried_columns); - auto & select = typeid_cast(*query_info.query); - determineArrayJoinedNames(select); - optimize(select); + determineArrayJoinedNames(query_info.query->as()); + optimize(query_info.query->as()); } @@ -66,7 +65,7 @@ static void collectIdentifiersNoSubqueries(const ASTPtr & ast, NameSet & set) if (auto opt_name = getIdentifierName(ast)) return (void)set.insert(*opt_name); - if (typeid_cast(ast.get())) + if (ast->as()) return; for (const auto & child : ast->children) @@ -75,7 +74,7 @@ static void collectIdentifiersNoSubqueries(const ASTPtr & ast, NameSet & set) void MergeTreeWhereOptimizer::analyzeImpl(Conditions & res, const ASTPtr & node) const { - if (const auto func_and = typeid_cast(node.get()); func_and && func_and->name == "and") + if (const auto * func_and = node->as(); func_and && func_and->name == "and") { for (const auto & elem : func_and->arguments->children) analyzeImpl(res, elem); @@ -219,7 +218,7 @@ UInt64 MergeTreeWhereOptimizer::getIdentifiersColumnSize(const NameSet & identif bool MergeTreeWhereOptimizer::isConditionGood(const ASTPtr & condition) const { - const auto function = typeid_cast(condition.get()); + const auto * function = condition->as(); if (!function) return false; @@ -232,13 +231,13 @@ bool MergeTreeWhereOptimizer::isConditionGood(const ASTPtr & condition) const auto right_arg = function->arguments->children.back().get(); /// try to ensure left_arg points to ASTIdentifier - if (!isIdentifier(left_arg) && isIdentifier(right_arg)) + if (!left_arg->as() && right_arg->as()) std::swap(left_arg, right_arg); - if (isIdentifier(left_arg)) + if (left_arg->as()) { /// condition may be "good" if only right_arg is a constant and its value is outside the threshold - if (const auto literal = typeid_cast(right_arg)) + if (const auto * literal = right_arg->as()) { const auto & field = literal->value; const auto type = field.getType(); @@ -268,7 +267,7 @@ bool MergeTreeWhereOptimizer::isConditionGood(const ASTPtr & condition) const bool MergeTreeWhereOptimizer::hasPrimaryKeyAtoms(const ASTPtr & ast) const { - if (const auto func = typeid_cast(ast.get())) + if (const auto * func = ast->as()) { const auto & args = func->arguments->children; @@ -288,7 +287,7 @@ bool MergeTreeWhereOptimizer::hasPrimaryKeyAtoms(const ASTPtr & ast) const bool MergeTreeWhereOptimizer::isPrimaryKeyAtom(const ASTPtr & ast) const { - if (const auto func = typeid_cast(ast.get())) + if (const auto * func = ast->as()) { if (!KeyCondition::atom_map.count(func->name)) return false; @@ -314,7 +313,7 @@ bool MergeTreeWhereOptimizer::isConstant(const ASTPtr & expr) const { const auto column_name = expr->getColumnName(); - if (typeid_cast(expr.get()) + if (expr->as() || (block_with_constants.has(column_name) && block_with_constants.getByName(column_name).column->isColumnConst())) return true; @@ -334,7 +333,7 @@ bool MergeTreeWhereOptimizer::isSubsetOfTableColumns(const NameSet & identifiers bool MergeTreeWhereOptimizer::cannotBeMoved(const ASTPtr & ptr) const { - if (const auto function_ptr = typeid_cast(ptr.get())) + if (const auto * function_ptr = ptr->as()) { /// disallow arrayJoin expressions to be moved to PREWHERE for now if ("arrayJoin" == function_ptr->name) diff --git a/dbms/src/Storages/MergeTree/registerStorageMergeTree.cpp b/dbms/src/Storages/MergeTree/registerStorageMergeTree.cpp index a64f376e3de..b08300bb039 100644 --- a/dbms/src/Storages/MergeTree/registerStorageMergeTree.cpp +++ b/dbms/src/Storages/MergeTree/registerStorageMergeTree.cpp @@ -37,7 +37,7 @@ namespace ErrorCodes */ static Names extractColumnNames(const ASTPtr & node) { - const ASTFunction * expr_func = typeid_cast(&*node); + const auto * expr_func = node->as(); if (expr_func && expr_func->name == "tuple") { @@ -476,7 +476,7 @@ static StoragePtr create(const StorageFactory::Arguments & args) if (replicated) { - auto ast = typeid_cast(engine_args[0].get()); + const auto * ast = engine_args[0]->as(); if (ast && ast->value.getType() == Field::Types::String) zookeeper_path = safeGet(ast->value); else @@ -484,7 +484,7 @@ static StoragePtr create(const StorageFactory::Arguments & args) "Path in ZooKeeper must be a string literal" + getMergeTreeVerboseHelp(is_extended_storage_def), ErrorCodes::BAD_ARGUMENTS); - ast = typeid_cast(engine_args[1].get()); + ast = engine_args[1]->as(); if (ast && ast->value.getType() == Field::Types::String) replica_name = safeGet(ast->value); else @@ -512,7 +512,7 @@ static StoragePtr create(const StorageFactory::Arguments & args) else if (merging_params.mode == MergeTreeData::MergingParams::Replacing) { /// If the last element is not index_granularity or replica_name (a literal), then this is the name of the version column. - if (!engine_args.empty() && !typeid_cast(engine_args.back().get())) + if (!engine_args.empty() && !engine_args.back()->as()) { if (!getIdentifierName(engine_args.back(), merging_params.version_column)) throw Exception( @@ -525,7 +525,7 @@ static StoragePtr create(const StorageFactory::Arguments & args) else if (merging_params.mode == MergeTreeData::MergingParams::Summing) { /// If the last element is not index_granularity or replica_name (a literal), then this is a list of summable columns. - if (!engine_args.empty() && !typeid_cast(engine_args.back().get())) + if (!engine_args.empty() && !engine_args.back()->as()) { merging_params.columns_to_sum = extractColumnNames(engine_args.back()); engine_args.pop_back(); @@ -537,7 +537,7 @@ static StoragePtr create(const StorageFactory::Arguments & args) String error_msg = "Last parameter of GraphiteMergeTree must be name (in single quotes) of element in configuration file with Graphite options"; error_msg += getMergeTreeVerboseHelp(is_extended_storage_def); - if (auto ast = typeid_cast(engine_args.back().get())) + if (const auto * ast = engine_args.back()->as()) { if (ast->value.getType() != Field::Types::String) throw Exception(error_msg, ErrorCodes::BAD_ARGUMENTS); @@ -618,7 +618,7 @@ static StoragePtr create(const StorageFactory::Arguments & args) order_by_ast = engine_args[1]; - auto ast = typeid_cast(engine_args.back().get()); + const auto * ast = engine_args.back()->as(); if (ast && ast->value.getType() == Field::Types::UInt64) storage_settings.index_granularity = safeGet(ast->value); else diff --git a/dbms/src/Storages/MutationCommands.cpp b/dbms/src/Storages/MutationCommands.cpp index 6ba9c23a257..349ecb66980 100644 --- a/dbms/src/Storages/MutationCommands.cpp +++ b/dbms/src/Storages/MutationCommands.cpp @@ -35,7 +35,7 @@ std::optional MutationCommand::parse(ASTAlterCommand * command) res.predicate = command->predicate; for (const ASTPtr & assignment_ast : command->update_assignments->children) { - const auto & assignment = typeid_cast(*assignment_ast); + const auto & assignment = assignment_ast->as(); auto insertion = res.column_to_update_expression.emplace(assignment.column_name, assignment.expression); if (!insertion.second) throw Exception("Multiple assignments in the single statement to column `" + assignment.column_name + "`", @@ -71,7 +71,7 @@ void MutationCommands::readText(ReadBuffer & in) ParserAlterCommandList p_alter_commands; auto commands_ast = parseQuery( p_alter_commands, commands_str.data(), commands_str.data() + commands_str.length(), "mutation commands list", 0); - for (ASTAlterCommand * command_ast : typeid_cast(*commands_ast).commands) + for (ASTAlterCommand * command_ast : commands_ast->as().commands) { auto command = MutationCommand::parse(command_ast); if (!command) diff --git a/dbms/src/Storages/StorageBuffer.cpp b/dbms/src/Storages/StorageBuffer.cpp index 5487db29703..c9a1e2fe52b 100644 --- a/dbms/src/Storages/StorageBuffer.cpp +++ b/dbms/src/Storages/StorageBuffer.cpp @@ -713,17 +713,17 @@ void registerStorageBuffer(StorageFactory & factory) engine_args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[0], args.local_context); engine_args[1] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[1], args.local_context); - String destination_database = static_cast(*engine_args[0]).value.safeGet(); - String destination_table = static_cast(*engine_args[1]).value.safeGet(); + String destination_database = engine_args[0]->as().value.safeGet(); + String destination_table = engine_args[1]->as().value.safeGet(); - UInt64 num_buckets = applyVisitor(FieldVisitorConvertToNumber(), typeid_cast(*engine_args[2]).value); + UInt64 num_buckets = applyVisitor(FieldVisitorConvertToNumber(), engine_args[2]->as().value); - Int64 min_time = applyVisitor(FieldVisitorConvertToNumber(), typeid_cast(*engine_args[3]).value); - Int64 max_time = applyVisitor(FieldVisitorConvertToNumber(), typeid_cast(*engine_args[4]).value); - UInt64 min_rows = applyVisitor(FieldVisitorConvertToNumber(), typeid_cast(*engine_args[5]).value); - UInt64 max_rows = applyVisitor(FieldVisitorConvertToNumber(), typeid_cast(*engine_args[6]).value); - UInt64 min_bytes = applyVisitor(FieldVisitorConvertToNumber(), typeid_cast(*engine_args[7]).value); - UInt64 max_bytes = applyVisitor(FieldVisitorConvertToNumber(), typeid_cast(*engine_args[8]).value); + Int64 min_time = applyVisitor(FieldVisitorConvertToNumber(), engine_args[3]->as().value); + Int64 max_time = applyVisitor(FieldVisitorConvertToNumber(), engine_args[4]->as().value); + UInt64 min_rows = applyVisitor(FieldVisitorConvertToNumber(), engine_args[5]->as().value); + UInt64 max_rows = applyVisitor(FieldVisitorConvertToNumber(), engine_args[6]->as().value); + UInt64 min_bytes = applyVisitor(FieldVisitorConvertToNumber(), engine_args[7]->as().value); + UInt64 max_bytes = applyVisitor(FieldVisitorConvertToNumber(), engine_args[8]->as().value); return StorageBuffer::create( args.table_name, args.columns, diff --git a/dbms/src/Storages/StorageDictionary.cpp b/dbms/src/Storages/StorageDictionary.cpp index 5aa2ea6b329..2c8c76005d4 100644 --- a/dbms/src/Storages/StorageDictionary.cpp +++ b/dbms/src/Storages/StorageDictionary.cpp @@ -101,7 +101,7 @@ void registerStorageDictionary(StorageFactory & factory) ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); args.engine_args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(args.engine_args[0], args.local_context); - String dictionary_name = typeid_cast(*args.engine_args[0]).value.safeGet(); + String dictionary_name = args.engine_args[0]->as().value.safeGet(); return StorageDictionary::create( args.table_name, args.columns, args.context, args.attach, dictionary_name); diff --git a/dbms/src/Storages/StorageDistributed.cpp b/dbms/src/Storages/StorageDistributed.cpp index 582dd976c0d..27963a914b2 100644 --- a/dbms/src/Storages/StorageDistributed.cpp +++ b/dbms/src/Storages/StorageDistributed.cpp @@ -75,9 +75,9 @@ ASTPtr rewriteSelectQuery(const ASTPtr & query, const std::string & database, co { auto modified_query_ast = query->clone(); if (table_function_ptr) - typeid_cast(*modified_query_ast).addTableFunction(table_function_ptr); + modified_query_ast->as().addTableFunction(table_function_ptr); else - typeid_cast(*modified_query_ast).replaceDatabaseAndTable(database, table); + modified_query_ast->as().replaceDatabaseAndTable(database, table); return modified_query_ast; } @@ -468,7 +468,7 @@ void StorageDistributed::ClusterNodeData::shutdownAndDropAllData() /// using constraints from "WHERE" condition, otherwise returns `nullptr` ClusterPtr StorageDistributed::skipUnusedShards(ClusterPtr cluster, const SelectQueryInfo & query_info) { - const auto & select = typeid_cast(*query_info.query); + const auto & select = query_info.query->as(); if (!select.where_expression) { @@ -528,8 +528,8 @@ void registerStorageDistributed(StorageFactory & factory) engine_args[1] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[1], args.local_context); engine_args[2] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[2], args.local_context); - String remote_database = static_cast(*engine_args[1]).value.safeGet(); - String remote_table = static_cast(*engine_args[2]).value.safeGet(); + String remote_database = engine_args[1]->as().value.safeGet(); + String remote_table = engine_args[2]->as().value.safeGet(); const auto & sharding_key = engine_args.size() == 4 ? engine_args[3] : nullptr; diff --git a/dbms/src/Storages/StorageFile.cpp b/dbms/src/Storages/StorageFile.cpp index 0c221cf3393..c6eb867f32a 100644 --- a/dbms/src/Storages/StorageFile.cpp +++ b/dbms/src/Storages/StorageFile.cpp @@ -291,7 +291,7 @@ void registerStorageFile(StorageFactory & factory) ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); engine_args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[0], args.local_context); - String format_name = static_cast(*engine_args[0]).value.safeGet(); + String format_name = engine_args[0]->as().value.safeGet(); int source_fd = -1; String source_path; @@ -311,7 +311,7 @@ void registerStorageFile(StorageFactory & factory) throw Exception("Unknown identifier '" + *opt_name + "' in second arg of File storage constructor", ErrorCodes::UNKNOWN_IDENTIFIER); } - else if (const ASTLiteral * literal = typeid_cast(engine_args[1].get())) + else if (const auto * literal = engine_args[1]->as()) { auto type = literal->value.getType(); if (type == Field::Types::Int64) diff --git a/dbms/src/Storages/StorageHDFS.cpp b/dbms/src/Storages/StorageHDFS.cpp index 4f6cf35c09e..aec846f1a58 100644 --- a/dbms/src/Storages/StorageHDFS.cpp +++ b/dbms/src/Storages/StorageHDFS.cpp @@ -163,11 +163,11 @@ void registerStorageHDFS(StorageFactory & factory) engine_args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[0], args.local_context); - String url = static_cast(*engine_args[0]).value.safeGet(); + String url = engine_args[0]->as().value.safeGet(); engine_args[1] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[1], args.local_context); - String format_name = static_cast(*engine_args[1]).value.safeGet(); + String format_name = engine_args[1]->as().value.safeGet(); return StorageHDFS::create(url, args.table_name, format_name, args.columns, args.context); }); diff --git a/dbms/src/Storages/StorageMaterializedView.cpp b/dbms/src/Storages/StorageMaterializedView.cpp index b19dd53ff49..fd63053c78a 100644 --- a/dbms/src/Storages/StorageMaterializedView.cpp +++ b/dbms/src/Storages/StorageMaterializedView.cpp @@ -48,14 +48,14 @@ static void extractDependentTable(ASTSelectQuery & query, String & select_databa else select_database_name = db_and_table->database; } - else if (auto ast_select = typeid_cast(subquery.get())) + else if (auto * ast_select = subquery->as()) { if (ast_select->list_of_selects->children.size() != 1) throw Exception("UNION is not supported for MATERIALIZED VIEW", ErrorCodes::QUERY_IS_NOT_SUPPORTED_IN_MATERIALIZED_VIEW); auto & inner_query = ast_select->list_of_selects->children.at(0); - extractDependentTable(typeid_cast(*inner_query), select_database_name, select_table_name); + extractDependentTable(inner_query->as(), select_database_name, select_table_name); } else throw Exception("Logical error while creating StorageMaterializedView." @@ -73,14 +73,14 @@ static void checkAllowedQueries(const ASTSelectQuery & query) if (!subquery) return; - if (auto ast_select = typeid_cast(subquery.get())) + if (const auto * ast_select = subquery->as()) { if (ast_select->list_of_selects->children.size() != 1) throw Exception("UNION is not supported for MATERIALIZED VIEW", ErrorCodes::QUERY_IS_NOT_SUPPORTED_IN_MATERIALIZED_VIEW); const auto & inner_query = ast_select->list_of_selects->children.at(0); - checkAllowedQueries(typeid_cast(*inner_query)); + checkAllowedQueries(inner_query->as()); } } @@ -110,7 +110,7 @@ StorageMaterializedView::StorageMaterializedView( inner_query = query.select->list_of_selects->children.at(0); - ASTSelectQuery & select_query = typeid_cast(*inner_query); + auto & select_query = inner_query->as(); extractDependentTable(select_query, select_database_name, select_table_name); checkAllowedQueries(select_query); diff --git a/dbms/src/Storages/StorageMaterializedView.h b/dbms/src/Storages/StorageMaterializedView.h index 8c2657b484b..665f6243a32 100644 --- a/dbms/src/Storages/StorageMaterializedView.h +++ b/dbms/src/Storages/StorageMaterializedView.h @@ -2,16 +2,13 @@ #include +#include #include namespace DB { -class IAST; // XXX: should include full class - for proper use inside inline methods -using ASTPtr = std::shared_ptr; - - class StorageMaterializedView : public ext::shared_ptr_helper, public IStorage { public: diff --git a/dbms/src/Storages/StorageMerge.cpp b/dbms/src/Storages/StorageMerge.cpp index 4521083bc03..225ae9c2a07 100644 --- a/dbms/src/Storages/StorageMerge.cpp +++ b/dbms/src/Storages/StorageMerge.cpp @@ -289,7 +289,7 @@ BlockInputStreams StorageMerge::createSourceStreams(const SelectQueryInfo & quer } else if (processed_stage > storage->getQueryProcessingStage(modified_context)) { - typeid_cast(modified_query_info.query.get())->replaceDatabaseAndTable(source_database, storage->getTableName()); + modified_query_info.query->as()->replaceDatabaseAndTable(source_database, storage->getTableName()); /// Maximum permissible parallelism is streams_num modified_context.getSettingsRef().max_threads = UInt64(streams_num); @@ -369,7 +369,7 @@ StorageMerge::StorageListWithLocks StorageMerge::getSelectedTables(const ASTPtr { StoragePtr storage = iterator->table(); - if (query && typeid_cast(query.get())->prewhere_expression && !storage->supportsPrewhere()) + if (query && query->as()->prewhere_expression && !storage->supportsPrewhere()) throw Exception("Storage " + storage->getName() + " doesn't support PREWHERE.", ErrorCodes::ILLEGAL_PREWHERE); if (storage.get() != this) @@ -440,7 +440,7 @@ void StorageMerge::convertingSourceStream(const Block & header, const Context & Block before_block_header = source_stream->getHeader(); source_stream = std::make_shared(context, source_stream, header, ConvertingBlockInputStream::MatchColumnsMode::Name); - ASTPtr where_expression = typeid_cast(query.get())->where_expression; + auto where_expression = query->as()->where_expression; if (!where_expression) return; @@ -491,8 +491,8 @@ void registerStorageMerge(StorageFactory & factory) engine_args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[0], args.local_context); engine_args[1] = evaluateConstantExpressionAsLiteral(engine_args[1], args.local_context); - String source_database = static_cast(*engine_args[0]).value.safeGet(); - String table_name_regexp = static_cast(*engine_args[1]).value.safeGet(); + String source_database = engine_args[0]->as().value.safeGet(); + String table_name_regexp = engine_args[1]->as().value.safeGet(); return StorageMerge::create( args.table_name, args.columns, diff --git a/dbms/src/Storages/StorageMergeTree.cpp b/dbms/src/Storages/StorageMergeTree.cpp index 856976ca35d..fd356e75e8f 100644 --- a/dbms/src/Storages/StorageMergeTree.cpp +++ b/dbms/src/Storages/StorageMergeTree.cpp @@ -235,7 +235,7 @@ void StorageMergeTree::alter( IDatabase::ASTModifier storage_modifier = [&] (IAST & ast) { - auto & storage_ast = typeid_cast(ast); + auto & storage_ast = ast.as(); if (new_order_by_ast.get() != data.order_by_ast.get()) storage_ast.set(storage_ast.order_by, new_order_by_ast); @@ -941,7 +941,7 @@ void StorageMergeTree::attachPartition(const ASTPtr & partition, bool attach_par String partition_id; if (attach_part) - partition_id = typeid_cast(*partition).value.safeGet(); + partition_id = partition->as().value.safeGet(); else partition_id = data.getPartitionIDFromQuery(partition, context); diff --git a/dbms/src/Storages/StorageMySQL.cpp b/dbms/src/Storages/StorageMySQL.cpp index 127caefcd3b..cd700529b79 100644 --- a/dbms/src/Storages/StorageMySQL.cpp +++ b/dbms/src/Storages/StorageMySQL.cpp @@ -199,21 +199,21 @@ void registerStorageMySQL(StorageFactory & factory) engine_args[i] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[i], args.local_context); /// 3306 is the default MySQL port. - auto parsed_host_port = parseAddress(static_cast(*engine_args[0]).value.safeGet(), 3306); + auto parsed_host_port = parseAddress(engine_args[0]->as().value.safeGet(), 3306); - const String & remote_database = static_cast(*engine_args[1]).value.safeGet(); - const String & remote_table = static_cast(*engine_args[2]).value.safeGet(); - const String & username = static_cast(*engine_args[3]).value.safeGet(); - const String & password = static_cast(*engine_args[4]).value.safeGet(); + const String & remote_database = engine_args[1]->as().value.safeGet(); + const String & remote_table = engine_args[2]->as().value.safeGet(); + const String & username = engine_args[3]->as().value.safeGet(); + const String & password = engine_args[4]->as().value.safeGet(); mysqlxx::Pool pool(remote_database, parsed_host_port.first, username, password, parsed_host_port.second); bool replace_query = false; std::string on_duplicate_clause; if (engine_args.size() >= 6) - replace_query = static_cast(*engine_args[5]).value.safeGet() > 0; + replace_query = engine_args[5]->as().value.safeGet(); if (engine_args.size() == 7) - on_duplicate_clause = static_cast(*engine_args[6]).value.safeGet(); + on_duplicate_clause = engine_args[6]->as().value.safeGet(); if (replace_query && !on_duplicate_clause.empty()) throw Exception( diff --git a/dbms/src/Storages/StorageReplicatedMergeTree.cpp b/dbms/src/Storages/StorageReplicatedMergeTree.cpp index 235e9ee1cb0..045d25c61b0 100644 --- a/dbms/src/Storages/StorageReplicatedMergeTree.cpp +++ b/dbms/src/Storages/StorageReplicatedMergeTree.cpp @@ -454,7 +454,7 @@ void StorageReplicatedMergeTree::setTableStructure(ColumnsDescription new_column storage_modifier = [&](IAST & ast) { - auto & storage_ast = typeid_cast(ast); + auto & storage_ast = ast.as(); if (!storage_ast.order_by) throw Exception( @@ -3526,7 +3526,7 @@ void StorageReplicatedMergeTree::attachPartition(const ASTPtr & partition, bool String partition_id; if (attach_part) - partition_id = typeid_cast(*partition).value.safeGet(); + partition_id = partition->as().value.safeGet(); else partition_id = data.getPartitionIDFromQuery(partition, query_context); @@ -3945,17 +3945,17 @@ void StorageReplicatedMergeTree::sendRequestToLeaderReplica(const ASTPtr & query /// TODO: add setters and getters interface for database and table fields of AST auto new_query = query->clone(); - if (auto * alter = typeid_cast(new_query.get())) + if (auto * alter = new_query->as()) { alter->database = leader_address.database; alter->table = leader_address.table; } - else if (auto * optimize = typeid_cast(new_query.get())) + else if (auto * optimize = new_query->as()) { optimize->database = leader_address.database; optimize->table = leader_address.table; } - else if (auto * drop = typeid_cast(new_query.get()); drop->kind == ASTDropQuery::Kind::Truncate) + else if (auto * drop = new_query->as(); drop->kind == ASTDropQuery::Kind::Truncate) { drop->database = leader_address.database; drop->table = leader_address.table; diff --git a/dbms/src/Storages/StorageURL.cpp b/dbms/src/Storages/StorageURL.cpp index 3224527123b..71575f27b1b 100644 --- a/dbms/src/Storages/StorageURL.cpp +++ b/dbms/src/Storages/StorageURL.cpp @@ -202,12 +202,12 @@ void registerStorageURL(StorageFactory & factory) engine_args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[0], args.local_context); - String url = static_cast(*engine_args[0]).value.safeGet(); + String url = engine_args[0]->as().value.safeGet(); Poco::URI uri(url); engine_args[1] = evaluateConstantExpressionOrIdentifierAsLiteral(engine_args[1], args.local_context); - String format_name = static_cast(*engine_args[1]).value.safeGet(); + String format_name = engine_args[1]->as().value.safeGet(); return StorageURL::create(uri, args.table_name, format_name, args.columns, args.context); }); diff --git a/dbms/src/Storages/StorageView.cpp b/dbms/src/Storages/StorageView.cpp index 97c085d16e3..dec57408746 100644 --- a/dbms/src/Storages/StorageView.cpp +++ b/dbms/src/Storages/StorageView.cpp @@ -54,7 +54,7 @@ BlockInputStreams StorageView::read( { auto new_inner_query = inner_query->clone(); auto new_outer_query = query_info.query->clone(); - auto new_outer_select = typeid_cast(new_outer_query.get()); + auto * new_outer_select = new_outer_query->as(); replaceTableNameWithSubquery(new_outer_select, new_inner_query); @@ -74,12 +74,12 @@ BlockInputStreams StorageView::read( void StorageView::replaceTableNameWithSubquery(ASTSelectQuery * select_query, ASTPtr & subquery) { - ASTTablesInSelectQueryElement * select_element = static_cast(select_query->tables->children[0].get()); + auto * select_element = select_query->tables->children[0]->as(); if (!select_element->table_expression) throw Exception("Logical error: incorrect table expression", ErrorCodes::LOGICAL_ERROR); - ASTTableExpression * table_expression = static_cast(select_element->table_expression.get()); + auto * table_expression = select_element->table_expression->as(); if (!table_expression->database_and_table_name) throw Exception("Logical error: incorrect table expression", ErrorCodes::LOGICAL_ERROR); diff --git a/dbms/src/Storages/StorageView.h b/dbms/src/Storages/StorageView.h index fed6664b8eb..afd9b5ce326 100644 --- a/dbms/src/Storages/StorageView.h +++ b/dbms/src/Storages/StorageView.h @@ -1,17 +1,15 @@ #pragma once -#include -#include #include +#include +#include + +#include namespace DB { -class IAST; -using ASTPtr = std::shared_ptr; - - class StorageView : public ext::shared_ptr_helper, public IStorage { public: diff --git a/dbms/src/Storages/StorageXDBC.cpp b/dbms/src/Storages/StorageXDBC.cpp index ac8f156305a..3b02b86ad7f 100644 --- a/dbms/src/Storages/StorageXDBC.cpp +++ b/dbms/src/Storages/StorageXDBC.cpp @@ -115,10 +115,10 @@ namespace BridgeHelperPtr bridge_helper = std::make_shared>(args.context, args.context.getSettingsRef().http_receive_timeout.value, - static_cast(*engine_args[0]).value.safeGet()); + engine_args[0]->as().value.safeGet()); return std::make_shared(args.table_name, - static_cast(*engine_args[1]).value.safeGet(), - static_cast(*engine_args[2]).value.safeGet(), + engine_args[1]->as().value.safeGet(), + engine_args[2]->as().value.safeGet(), args.columns, args.context, bridge_helper); diff --git a/dbms/src/Storages/System/StorageSystemTables.cpp b/dbms/src/Storages/System/StorageSystemTables.cpp index 498035ea7c0..3413e8609f4 100644 --- a/dbms/src/Storages/System/StorageSystemTables.cpp +++ b/dbms/src/Storages/System/StorageSystemTables.cpp @@ -240,7 +240,7 @@ protected: if (ast) { - const ASTCreateQuery & ast_create = typeid_cast(*ast); + const auto & ast_create = ast->as(); if (ast_create.storage) { engine_full = queryToString(*ast_create.storage); diff --git a/dbms/src/Storages/System/StorageSystemZooKeeper.cpp b/dbms/src/Storages/System/StorageSystemZooKeeper.cpp index dd7eb033514..a23d0a79957 100644 --- a/dbms/src/Storages/System/StorageSystemZooKeeper.cpp +++ b/dbms/src/Storages/System/StorageSystemZooKeeper.cpp @@ -44,7 +44,7 @@ NamesAndTypesList StorageSystemZooKeeper::getNamesAndTypes() static bool extractPathImpl(const IAST & elem, String & res) { - const ASTFunction * function = typeid_cast(&elem); + const auto * function = elem.as(); if (!function) return false; @@ -59,24 +59,24 @@ static bool extractPathImpl(const IAST & elem, String & res) if (function->name == "equals") { - const ASTExpressionList & args = typeid_cast(*function->arguments); + const auto & args = function->arguments->as(); const IAST * value; if (args.children.size() != 2) return false; const ASTIdentifier * ident; - if ((ident = typeid_cast(&*args.children.at(0)))) - value = &*args.children.at(1); - else if ((ident = typeid_cast(&*args.children.at(1)))) - value = &*args.children.at(0); + if ((ident = args.children.at(0)->as())) + value = args.children.at(1).get(); + else if ((ident = args.children.at(1)->as())) + value = args.children.at(0).get(); else return false; if (ident->name != "path") return false; - const ASTLiteral * literal = typeid_cast(value); + const auto * literal = value->as(); if (!literal) return false; @@ -95,7 +95,7 @@ static bool extractPathImpl(const IAST & elem, String & res) */ static String extractPath(const ASTPtr & query) { - const ASTSelectQuery & select = typeid_cast(*query); + const auto & select = query->as(); if (!select.where_expression) return ""; diff --git a/dbms/src/Storages/VirtualColumnUtils.cpp b/dbms/src/Storages/VirtualColumnUtils.cpp index b7ea5e66c37..9f634cc0eec 100644 --- a/dbms/src/Storages/VirtualColumnUtils.cpp +++ b/dbms/src/Storages/VirtualColumnUtils.cpp @@ -76,18 +76,17 @@ String chooseSuffixForSet(const NamesAndTypesList & columns, const std::vector(*ast); + auto & select = ast->as(); if (!select.with_expression_list) { select.with_expression_list = std::make_shared(); select.children.insert(select.children.begin(), select.with_expression_list); } - ASTExpressionList & with = typeid_cast(*select.with_expression_list); auto literal = std::make_shared(value); literal->alias = column_name; literal->prefer_alias_to_column_name = true; - with.children.push_back(literal); + select.with_expression_list->children.push_back(literal); } /// Verifying that the function depends only on the specified columns @@ -106,7 +105,7 @@ static bool isValidFunction(const ASTPtr & expression, const NameSet & columns) /// Extract all subfunctions of the main conjunction, but depending only on the specified columns static void extractFunctions(const ASTPtr & expression, const NameSet & columns, std::vector & result) { - const ASTFunction * function = typeid_cast(expression.get()); + const auto * function = expression->as(); if (function && function->name == "and") { for (size_t i = 0; i < function->arguments->children.size(); ++i) @@ -126,7 +125,7 @@ static ASTPtr buildWhereExpression(const ASTs & functions) if (functions.size() == 1) return functions[0]; ASTPtr new_query = std::make_shared(); - ASTFunction & new_function = typeid_cast(*new_query); + auto & new_function = new_query->as(); new_function.name = "and"; new_function.arguments = std::make_shared(); new_function.arguments->children = functions; @@ -136,7 +135,7 @@ static ASTPtr buildWhereExpression(const ASTs & functions) void filterBlockWithQuery(const ASTPtr & query, Block & block, const Context & context) { - const ASTSelectQuery & select = typeid_cast(*query); + const auto & select = query->as(); if (!select.where_expression && !select.prewhere_expression) return; diff --git a/dbms/src/Storages/getStructureOfRemoteTable.cpp b/dbms/src/Storages/getStructureOfRemoteTable.cpp index bf867c1cad8..e6e8cea7f78 100644 --- a/dbms/src/Storages/getStructureOfRemoteTable.cpp +++ b/dbms/src/Storages/getStructureOfRemoteTable.cpp @@ -39,7 +39,7 @@ ColumnsDescription getStructureOfRemoteTable( { if (shard_info.isLocal()) { - auto table_function = static_cast(table_func_ptr.get()); + const auto * table_function = table_func_ptr->as(); return TableFunctionFactory::instance().get(table_function->name, context)->execute(table_func_ptr, context)->getColumns(); } diff --git a/dbms/src/Storages/transformQueryForExternalDatabase.cpp b/dbms/src/Storages/transformQueryForExternalDatabase.cpp index aea176def3b..d54996515d3 100644 --- a/dbms/src/Storages/transformQueryForExternalDatabase.cpp +++ b/dbms/src/Storages/transformQueryForExternalDatabase.cpp @@ -21,10 +21,10 @@ static void replaceConstFunction(IAST & node, const Context & context, const Nam for (size_t i = 0; i < node.children.size(); ++i) { auto child = node.children[i]; - if (ASTExpressionList * exp_list = typeid_cast(&*child)) + if (auto * exp_list = child->as()) replaceConstFunction(*exp_list, context, all_columns); - if (ASTFunction * function = typeid_cast(&*child)) + if (auto * function = child->as()) { NamesAndTypesList source_columns = all_columns; ASTPtr query = function->ptr(); @@ -42,7 +42,7 @@ static void replaceConstFunction(IAST & node, const Context & context, const Nam static bool isCompatible(const IAST & node) { - if (const ASTFunction * function = typeid_cast(&node)) + if (const auto * function = node.as()) { String name = function->name; if (!(name == "and" @@ -66,7 +66,7 @@ static bool isCompatible(const IAST & node) return true; } - if (const ASTLiteral * literal = typeid_cast(&node)) + if (const auto * literal = node.as()) { /// Foreign databases often have no support for Array and Tuple literals. if (literal->value.getType() == Field::Types::Array @@ -76,7 +76,7 @@ static bool isCompatible(const IAST & node) return true; } - if (isIdentifier(&node)) + if (node.as()) return true; return false; @@ -112,7 +112,7 @@ String transformQueryForExternalDatabase( * copy only compatible parts of it. */ - ASTPtr & original_where = typeid_cast(*clone_query).where_expression; + auto & original_where = clone_query->as().where_expression; if (original_where) { replaceConstFunction(*original_where, context, available_columns); @@ -120,7 +120,7 @@ String transformQueryForExternalDatabase( { select->where_expression = original_where; } - else if (const ASTFunction * function = typeid_cast(original_where.get())) + else if (const auto * function = original_where->as()) { if (function->name == "and") { diff --git a/dbms/src/TableFunctions/ITableFunction.h b/dbms/src/TableFunctions/ITableFunction.h index ddf900fa65c..39ef7857a3a 100644 --- a/dbms/src/TableFunctions/ITableFunction.h +++ b/dbms/src/TableFunctions/ITableFunction.h @@ -1,14 +1,15 @@ #pragma once +#include + #include #include + namespace DB { class Context; -class IAST; -using ASTPtr = std::shared_ptr; class IStorage; using StoragePtr = std::shared_ptr; diff --git a/dbms/src/TableFunctions/ITableFunctionFileLike.cpp b/dbms/src/TableFunctions/ITableFunctionFileLike.cpp index 6495fbb92e3..fe8af831b56 100644 --- a/dbms/src/TableFunctions/ITableFunctionFileLike.cpp +++ b/dbms/src/TableFunctions/ITableFunctionFileLike.cpp @@ -22,12 +22,12 @@ namespace ErrorCodes StoragePtr ITableFunctionFileLike::executeImpl(const ASTPtr & ast_function, const Context & context) const { // Parse args - ASTs & args_func = typeid_cast(*ast_function).children; + ASTs & args_func = ast_function->children; if (args_func.size() != 1) throw Exception("Table function '" + getName() + "' must have arguments.", ErrorCodes::LOGICAL_ERROR); - ASTs & args = typeid_cast(*args_func.at(0)).children; + ASTs & args = args_func.at(0)->children; if (args.size() != 3) throw Exception("Table function '" + getName() + "' requires exactly 3 arguments: filename, format and structure.", @@ -36,9 +36,9 @@ StoragePtr ITableFunctionFileLike::executeImpl(const ASTPtr & ast_function, cons for (size_t i = 0; i < 3; ++i) args[i] = evaluateConstantExpressionOrIdentifierAsLiteral(args[i], context); - std::string filename = static_cast(*args[0]).value.safeGet(); - std::string format = static_cast(*args[1]).value.safeGet(); - std::string structure = static_cast(*args[2]).value.safeGet(); + std::string filename = args[0]->as().value.safeGet(); + std::string format = args[1]->as().value.safeGet(); + std::string structure = args[2]->as().value.safeGet(); // Create sample block std::vector structure_vals; diff --git a/dbms/src/TableFunctions/ITableFunctionXDBC.cpp b/dbms/src/TableFunctions/ITableFunctionXDBC.cpp index 455512310db..32011dc8f8a 100644 --- a/dbms/src/TableFunctions/ITableFunctionXDBC.cpp +++ b/dbms/src/TableFunctions/ITableFunctionXDBC.cpp @@ -29,12 +29,12 @@ namespace ErrorCodes StoragePtr ITableFunctionXDBC::executeImpl(const ASTPtr & ast_function, const Context & context) const { - const ASTFunction & args_func = typeid_cast(*ast_function); + const auto & args_func = ast_function->as(); if (!args_func.arguments) throw Exception("Table function '" + getName() + "' must have arguments.", ErrorCodes::LOGICAL_ERROR); - ASTs & args = typeid_cast(*args_func.arguments).children; + ASTs & args = args_func.arguments->children; if (args.size() != 2 && args.size() != 3) throw Exception("Table function '" + getName() + "' requires 2 or 3 arguments: " + getName() + "('DSN', table) or " + getName() + "('DSN', schema, table)", @@ -49,14 +49,14 @@ StoragePtr ITableFunctionXDBC::executeImpl(const ASTPtr & ast_function, const Co if (args.size() == 3) { - connection_string = static_cast(*args[0]).value.safeGet(); - schema_name = static_cast(*args[1]).value.safeGet(); - table_name = static_cast(*args[2]).value.safeGet(); + connection_string = args[0]->as().value.safeGet(); + schema_name = args[1]->as().value.safeGet(); + table_name = args[2]->as().value.safeGet(); } else if (args.size() == 2) { - connection_string = static_cast(*args[0]).value.safeGet(); - table_name = static_cast(*args[1]).value.safeGet(); + connection_string = args[0]->as().value.safeGet(); + table_name = args[1]->as().value.safeGet(); } /* Infer external table structure */ diff --git a/dbms/src/TableFunctions/TableFunctionCatBoostPool.cpp b/dbms/src/TableFunctions/TableFunctionCatBoostPool.cpp index 09ab2d3d7bb..23ff34be827 100644 --- a/dbms/src/TableFunctions/TableFunctionCatBoostPool.cpp +++ b/dbms/src/TableFunctions/TableFunctionCatBoostPool.cpp @@ -18,7 +18,7 @@ namespace ErrorCodes StoragePtr TableFunctionCatBoostPool::executeImpl(const ASTPtr & ast_function, const Context & context) const { - ASTs & args_func = typeid_cast(*ast_function).children; + ASTs & args_func = ast_function->children; std::string err = "Table function '" + getName() + "' requires 2 parameters: " + "column descriptions file, dataset description file"; @@ -26,14 +26,14 @@ StoragePtr TableFunctionCatBoostPool::executeImpl(const ASTPtr & ast_function, c if (args_func.size() != 1) throw Exception(err, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - ASTs & args = typeid_cast(*args_func.at(0)).children; + ASTs & args = args_func.at(0)->children; if (args.size() != 2) throw Exception(err, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); auto getStringLiteral = [](const IAST & node, const char * description) { - auto lit = typeid_cast(&node); + const auto * lit = node.as(); if (!lit) throw Exception(description + String(" must be string literal (in single quotes)."), ErrorCodes::BAD_ARGUMENTS); diff --git a/dbms/src/TableFunctions/TableFunctionMerge.cpp b/dbms/src/TableFunctions/TableFunctionMerge.cpp index b5dace28d41..9dfd94b6512 100644 --- a/dbms/src/TableFunctions/TableFunctionMerge.cpp +++ b/dbms/src/TableFunctions/TableFunctionMerge.cpp @@ -56,14 +56,14 @@ static NamesAndTypesList chooseColumns(const String & source_database, const Str StoragePtr TableFunctionMerge::executeImpl(const ASTPtr & ast_function, const Context & context) const { - ASTs & args_func = typeid_cast(*ast_function).children; + ASTs & args_func = ast_function->children; if (args_func.size() != 1) throw Exception("Table function 'merge' requires exactly 2 arguments" " - name of source database and regexp for table names.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - ASTs & args = typeid_cast(*args_func.at(0)).children; + ASTs & args = args_func.at(0)->children; if (args.size() != 2) throw Exception("Table function 'merge' requires exactly 2 arguments" @@ -73,8 +73,8 @@ StoragePtr TableFunctionMerge::executeImpl(const ASTPtr & ast_function, const Co args[0] = evaluateConstantExpressionOrIdentifierAsLiteral(args[0], context); args[1] = evaluateConstantExpressionAsLiteral(args[1], context); - String source_database = static_cast(*args[0]).value.safeGet(); - String table_name_regexp = static_cast(*args[1]).value.safeGet(); + String source_database = args[0]->as().value.safeGet(); + String table_name_regexp = args[1]->as().value.safeGet(); auto res = StorageMerge::create( getName(), diff --git a/dbms/src/TableFunctions/TableFunctionMySQL.cpp b/dbms/src/TableFunctions/TableFunctionMySQL.cpp index 26dc6e200b9..e335af45bcb 100644 --- a/dbms/src/TableFunctions/TableFunctionMySQL.cpp +++ b/dbms/src/TableFunctions/TableFunctionMySQL.cpp @@ -87,12 +87,12 @@ DataTypePtr getDataType(const String & mysql_data_type, bool is_nullable, bool i StoragePtr TableFunctionMySQL::executeImpl(const ASTPtr & ast_function, const Context & context) const { - const ASTFunction & args_func = typeid_cast(*ast_function); + const auto & args_func = ast_function->as(); if (!args_func.arguments) throw Exception("Table function 'mysql' must have arguments.", ErrorCodes::LOGICAL_ERROR); - ASTs & args = typeid_cast(*args_func.arguments).children; + ASTs & args = args_func.arguments->children; if (args.size() < 5 || args.size() > 7) throw Exception("Table function 'mysql' requires 5-7 parameters: MySQL('host:port', database, table, 'user', 'password'[, replace_query, 'on_duplicate_clause']).", @@ -101,18 +101,18 @@ StoragePtr TableFunctionMySQL::executeImpl(const ASTPtr & ast_function, const Co for (size_t i = 0; i < args.size(); ++i) args[i] = evaluateConstantExpressionOrIdentifierAsLiteral(args[i], context); - std::string host_port = static_cast(*args[0]).value.safeGet(); - std::string database_name = static_cast(*args[1]).value.safeGet(); - std::string table_name = static_cast(*args[2]).value.safeGet(); - std::string user_name = static_cast(*args[3]).value.safeGet(); - std::string password = static_cast(*args[4]).value.safeGet(); + std::string host_port = args[0]->as().value.safeGet(); + std::string database_name = args[1]->as().value.safeGet(); + std::string table_name = args[2]->as().value.safeGet(); + std::string user_name = args[3]->as().value.safeGet(); + std::string password = args[4]->as().value.safeGet(); bool replace_query = false; std::string on_duplicate_clause; if (args.size() >= 6) - replace_query = static_cast(*args[5]).value.safeGet() > 0; + replace_query = args[5]->as().value.safeGet() > 0; if (args.size() == 7) - on_duplicate_clause = static_cast(*args[6]).value.safeGet(); + on_duplicate_clause = args[6]->as().value.safeGet(); if (replace_query && !on_duplicate_clause.empty()) throw Exception( diff --git a/dbms/src/TableFunctions/TableFunctionNumbers.cpp b/dbms/src/TableFunctions/TableFunctionNumbers.cpp index 8226542d9ee..a02cd904882 100644 --- a/dbms/src/TableFunctions/TableFunctionNumbers.cpp +++ b/dbms/src/TableFunctions/TableFunctionNumbers.cpp @@ -19,7 +19,7 @@ namespace ErrorCodes StoragePtr TableFunctionNumbers::executeImpl(const ASTPtr & ast_function, const Context & context) const { - if (const ASTFunction * function = typeid_cast(ast_function.get())) + if (const auto * function = ast_function->as()) { auto arguments = function->arguments->children; @@ -45,7 +45,7 @@ void registerTableFunctionNumbers(TableFunctionFactory & factory) UInt64 TableFunctionNumbers::evaluateArgument(const Context & context, ASTPtr & argument) const { - return static_cast(*evaluateConstantExpressionOrIdentifierAsLiteral(argument, context)).value.safeGet(); + return evaluateConstantExpressionOrIdentifierAsLiteral(argument, context)->as().value.safeGet(); } } diff --git a/dbms/src/TableFunctions/TableFunctionRemote.cpp b/dbms/src/TableFunctions/TableFunctionRemote.cpp index 716819c836d..21611500eb7 100644 --- a/dbms/src/TableFunctions/TableFunctionRemote.cpp +++ b/dbms/src/TableFunctions/TableFunctionRemote.cpp @@ -27,12 +27,12 @@ namespace ErrorCodes StoragePtr TableFunctionRemote::executeImpl(const ASTPtr & ast_function, const Context & context) const { - ASTs & args_func = typeid_cast(*ast_function).children; + ASTs & args_func = ast_function->children; if (args_func.size() != 1) throw Exception(help_message, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - ASTs & args = typeid_cast(*args_func.at(0)).children; + ASTs & args = args_func.at(0)->children; const size_t max_args = is_cluster_function ? 3 : 5; if (args.size() < 2 || args.size() > max_args) @@ -50,7 +50,7 @@ StoragePtr TableFunctionRemote::executeImpl(const ASTPtr & ast_function, const C auto getStringLiteral = [](const IAST & node, const char * description) { - const ASTLiteral * lit = typeid_cast(&node); + const auto * lit = node.as(); if (!lit) throw Exception(description + String(" must be string literal (in single quotes)."), ErrorCodes::BAD_ARGUMENTS); @@ -63,7 +63,7 @@ StoragePtr TableFunctionRemote::executeImpl(const ASTPtr & ast_function, const C if (is_cluster_function) { ASTPtr ast_name = evaluateConstantExpressionOrIdentifierAsLiteral(args[arg_num], context); - cluster_name = static_cast(*ast_name).value.safeGet(); + cluster_name = ast_name->as().value.safeGet(); } else { @@ -74,7 +74,7 @@ StoragePtr TableFunctionRemote::executeImpl(const ASTPtr & ast_function, const C args[arg_num] = evaluateConstantExpressionOrIdentifierAsLiteral(args[arg_num], context); - const auto function = typeid_cast(args[arg_num].get()); + const auto * function = args[arg_num]->as(); if (function && TableFunctionFactory::instance().isTableFunctionName(function->name)) { @@ -83,7 +83,7 @@ StoragePtr TableFunctionRemote::executeImpl(const ASTPtr & ast_function, const C } else { - remote_database = static_cast(*args[arg_num]).value.safeGet(); + remote_database = args[arg_num]->as().value.safeGet(); ++arg_num; @@ -103,7 +103,7 @@ StoragePtr TableFunctionRemote::executeImpl(const ASTPtr & ast_function, const C else { args[arg_num] = evaluateConstantExpressionOrIdentifierAsLiteral(args[arg_num], context); - remote_table = static_cast(*args[arg_num]).value.safeGet(); + remote_table = args[arg_num]->as().value.safeGet(); ++arg_num; } }