diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index 64a58e33231..23cb753e96f 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -958,28 +958,16 @@ void InterpreterSelectQuery::executeFetchColumns( const Settings & settings = context->getSettingsRef(); /// Optimization for trivial query like SELECT count() FROM table. - auto check_trivial_count_query = [&]() -> std::optional + bool optimize_trivial_count = + syntax_analyzer_result->optimize_trivial_count && storage && + processing_stage == QueryProcessingStage::FetchColumns && + query_analyzer->hasAggregation() && (query_analyzer->aggregates().size() == 1) && + typeid_cast(query_analyzer->aggregates()[0].function.get()); + + if (optimize_trivial_count) { - if (!settings.optimize_trivial_count_query || !syntax_analyzer_result->maybe_optimize_trivial_count || !storage - || query.sampleSize() || query.sampleOffset() || query.final() || query.prewhere() || query.where() || query.groupBy() - || !query_analyzer->hasAggregation() || processing_stage != QueryProcessingStage::FetchColumns) - return {}; - - const AggregateDescriptions & aggregates = query_analyzer->aggregates(); - - if (aggregates.size() != 1) - return {}; - - const AggregateDescription & desc = aggregates[0]; - if (typeid_cast(desc.function.get())) - return desc; - - return {}; - }; - - if (auto desc = check_trivial_count_query()) - { - auto func = desc->function; + const auto & desc = query_analyzer->aggregates()[0]; + const auto & func = desc.function; std::optional num_rows = storage->totalRows(); if (num_rows) { @@ -998,13 +986,13 @@ void InterpreterSelectQuery::executeFetchColumns( column->insertFrom(place); auto header = analysis_result.before_aggregation->getSampleBlock(); - size_t arguments_size = desc->argument_names.size(); + size_t arguments_size = desc.argument_names.size(); DataTypes argument_types(arguments_size); for (size_t j = 
0; j < arguments_size; ++j) - argument_types[j] = header.getByName(desc->argument_names[j]).type; + argument_types[j] = header.getByName(desc.argument_names[j]).type; Block block_with_count{ - {std::move(column), std::make_shared(func, argument_types, desc->parameters), desc->column_name}}; + {std::move(column), std::make_shared(func, argument_types, desc.parameters), desc.column_name}}; auto istream = std::make_shared(block_with_count); pipeline.init(Pipe(std::make_shared(istream))); diff --git a/src/Interpreters/SyntaxAnalyzer.cpp b/src/Interpreters/SyntaxAnalyzer.cpp index 831379090ad..5f1bf79e053 100644 --- a/src/Interpreters/SyntaxAnalyzer.cpp +++ b/src/Interpreters/SyntaxAnalyzer.cpp @@ -598,7 +598,7 @@ void SyntaxAnalyzerResult::collectSourceColumns(bool add_special) /// Calculate which columns are required to execute the expression. /// Then, delete all other columns from the list of available columns. /// After execution, columns will only contain the list of columns needed to read from the table. -void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query) +void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query, bool is_select) { /// We calculate required_source_columns with source_columns modifications and swap them on exit required_source_columns = source_columns; @@ -648,12 +648,11 @@ void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query) required.insert(column_name_type.name); } - const auto * select_query = query->as(); - /// You need to read at least one column to find the number of rows. - if (select_query && required.empty()) + if (is_select && required.empty()) { - maybe_optimize_trivial_count = true; + optimize_trivial_count = true; + /// We will find a column with minimum <compressed_size, type_size, uncompressed_size>. /// Because it is the column that is cheapest to read. 
struct ColumnSizeTuple @@ -662,12 +661,14 @@ void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query) size_t type_size; size_t uncompressed_size; String name; + bool operator<(const ColumnSizeTuple & that) const { return std::tie(compressed_size, type_size, uncompressed_size) < std::tie(that.compressed_size, that.type_size, that.uncompressed_size); } }; + std::vector columns; if (storage) { @@ -681,6 +682,7 @@ void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query) columns.emplace_back(ColumnSizeTuple{c->second.data_compressed, type_size, c->second.data_uncompressed, source_column.name}); } } + if (!columns.empty()) required.insert(std::min_element(columns.begin(), columns.end())->name); else @@ -760,6 +762,7 @@ void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query) required_source_columns.swap(source_columns); } + SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyzeSelect( ASTPtr & query, SyntaxAnalyzerResult && result, @@ -848,7 +851,14 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyzeSelect( } result.aggregates = getAggregates(query, *select_query); - result.collectUsedColumns(query); + result.collectUsedColumns(query, true); + + if (result.optimize_trivial_count) + result.optimize_trivial_count = settings.optimize_trivial_count_query && + !select_query->where() && !select_query->prewhere() && !select_query->groupBy() && !select_query->having() && + !select_query->sampleSize() && !select_query->sampleOffset() && !select_query->final() && + (tables_with_column_names.size() < 2 || isLeft(result.analyzed_join->kind())); + return std::make_shared(result); } @@ -882,7 +892,7 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze(ASTPtr & query, const NamesAndTy else assertNoAggregates(query, "in wrong place"); - result.collectUsedColumns(query); + result.collectUsedColumns(query, false); return std::make_shared(result); } diff --git a/src/Interpreters/SyntaxAnalyzer.h b/src/Interpreters/SyntaxAnalyzer.h index abacb25ac4d..175c2db295a 
100644 --- a/src/Interpreters/SyntaxAnalyzer.h +++ b/src/Interpreters/SyntaxAnalyzer.h @@ -46,11 +46,11 @@ struct SyntaxAnalyzerResult /// Predicate optimizer overrides the sub queries bool rewrite_subqueries = false; + bool optimize_trivial_count = false; + /// Results of scalar sub queries Scalars scalars; - bool maybe_optimize_trivial_count = false; - SyntaxAnalyzerResult(const NamesAndTypesList & source_columns_, ConstStoragePtr storage_ = {}, bool add_special = true) : storage(storage_) , source_columns(source_columns_) @@ -59,7 +59,7 @@ struct SyntaxAnalyzerResult } void collectSourceColumns(bool add_special); - void collectUsedColumns(const ASTPtr & query); + void collectUsedColumns(const ASTPtr & query, bool is_select); Names requiredSourceColumns() const { return required_source_columns.getNames(); } const Scalars & getScalars() const { return scalars; } }; diff --git a/tests/queries/0_stateless/01143_trivial_count_with_join.reference b/tests/queries/0_stateless/01143_trivial_count_with_join.reference new file mode 100644 index 00000000000..9c3f6a570ce --- /dev/null +++ b/tests/queries/0_stateless/01143_trivial_count_with_join.reference @@ -0,0 +1,5 @@ +4 +4 +4 +4 +4 diff --git a/tests/queries/0_stateless/01143_trivial_count_with_join.sql b/tests/queries/0_stateless/01143_trivial_count_with_join.sql new file mode 100644 index 00000000000..d31750e37dc --- /dev/null +++ b/tests/queries/0_stateless/01143_trivial_count_with_join.sql @@ -0,0 +1,10 @@ +drop table if exists t; +create table t engine Memory as select * from numbers(2); + +select count(*) from t, numbers(2) r; +select count(*) from t cross join numbers(2) r; +select count() from t cross join numbers(2) r; +select count(t.number) from t cross join numbers(2) r; +select count(r.number) from t cross join numbers(2) r; + +drop table t;