Fix trivial count() optimization for queries with JOIN

This commit is contained in:
Artem Zuikov 2020-06-05 01:01:40 +03:00
parent 52aada4b80
commit 6d211bec19
5 changed files with 47 additions and 34 deletions

View File

@ -958,28 +958,16 @@ void InterpreterSelectQuery::executeFetchColumns(
const Settings & settings = context->getSettingsRef();
/// Optimization for trivial query like SELECT count() FROM table.
auto check_trivial_count_query = [&]() -> std::optional<AggregateDescription>
bool optimize_trivial_count =
syntax_analyzer_result->optimize_trivial_count && storage &&
processing_stage == QueryProcessingStage::FetchColumns &&
query_analyzer->hasAggregation() && (query_analyzer->aggregates().size() == 1) &&
typeid_cast<AggregateFunctionCount *>(query_analyzer->aggregates()[0].function.get());
if (optimize_trivial_count)
{
if (!settings.optimize_trivial_count_query || !syntax_analyzer_result->maybe_optimize_trivial_count || !storage
|| query.sampleSize() || query.sampleOffset() || query.final() || query.prewhere() || query.where() || query.groupBy()
|| !query_analyzer->hasAggregation() || processing_stage != QueryProcessingStage::FetchColumns)
return {};
const AggregateDescriptions & aggregates = query_analyzer->aggregates();
if (aggregates.size() != 1)
return {};
const AggregateDescription & desc = aggregates[0];
if (typeid_cast<AggregateFunctionCount *>(desc.function.get()))
return desc;
return {};
};
if (auto desc = check_trivial_count_query())
{
auto func = desc->function;
auto & desc = query_analyzer->aggregates()[0];
auto & func = desc.function;
std::optional<UInt64> num_rows = storage->totalRows();
if (num_rows)
{
@ -998,13 +986,13 @@ void InterpreterSelectQuery::executeFetchColumns(
column->insertFrom(place);
auto header = analysis_result.before_aggregation->getSampleBlock();
size_t arguments_size = desc->argument_names.size();
size_t arguments_size = desc.argument_names.size();
DataTypes argument_types(arguments_size);
for (size_t j = 0; j < arguments_size; ++j)
argument_types[j] = header.getByName(desc->argument_names[j]).type;
argument_types[j] = header.getByName(desc.argument_names[j]).type;
Block block_with_count{
{std::move(column), std::make_shared<DataTypeAggregateFunction>(func, argument_types, desc->parameters), desc->column_name}};
{std::move(column), std::make_shared<DataTypeAggregateFunction>(func, argument_types, desc.parameters), desc.column_name}};
auto istream = std::make_shared<OneBlockInputStream>(block_with_count);
pipeline.init(Pipe(std::make_shared<SourceFromInputStream>(istream)));

View File

@ -598,7 +598,7 @@ void SyntaxAnalyzerResult::collectSourceColumns(bool add_special)
/// Calculate which columns are required to execute the expression.
/// Then, delete all other columns from the list of available columns.
/// After execution, columns will only contain the list of columns needed to read from the table.
void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query)
void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query, bool is_select)
{
/// We calculate required_source_columns with source_columns modifications and swap them on exit
required_source_columns = source_columns;
@ -648,12 +648,11 @@ void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query)
required.insert(column_name_type.name);
}
const auto * select_query = query->as<ASTSelectQuery>();
/// You need to read at least one column to find the number of rows.
if (select_query && required.empty())
if (is_select && required.empty())
{
maybe_optimize_trivial_count = true;
optimize_trivial_count = true;
/// We will find a column with minimum <compressed_size, type_size, uncompressed_size>.
/// Because it is the column that is cheapest to read.
struct ColumnSizeTuple
@ -662,12 +661,14 @@ void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query)
size_t type_size;
size_t uncompressed_size;
String name;
bool operator<(const ColumnSizeTuple & that) const
{
return std::tie(compressed_size, type_size, uncompressed_size)
< std::tie(that.compressed_size, that.type_size, that.uncompressed_size);
}
};
std::vector<ColumnSizeTuple> columns;
if (storage)
{
@ -681,6 +682,7 @@ void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query)
columns.emplace_back(ColumnSizeTuple{c->second.data_compressed, type_size, c->second.data_uncompressed, source_column.name});
}
}
if (!columns.empty())
required.insert(std::min_element(columns.begin(), columns.end())->name);
else
@ -760,6 +762,7 @@ void SyntaxAnalyzerResult::collectUsedColumns(const ASTPtr & query)
required_source_columns.swap(source_columns);
}
SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyzeSelect(
ASTPtr & query,
SyntaxAnalyzerResult && result,
@ -848,7 +851,14 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyzeSelect(
}
result.aggregates = getAggregates(query, *select_query);
result.collectUsedColumns(query);
result.collectUsedColumns(query, true);
if (result.optimize_trivial_count)
result.optimize_trivial_count = settings.optimize_trivial_count_query &&
!select_query->where() && !select_query->prewhere() && !select_query->groupBy() && !select_query->having() &&
!select_query->sampleSize() && !select_query->sampleOffset() && !select_query->final() &&
(tables_with_column_names.size() < 2 || isLeft(result.analyzed_join->kind()));
return std::make_shared<const SyntaxAnalyzerResult>(result);
}
@ -882,7 +892,7 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze(ASTPtr & query, const NamesAndTy
else
assertNoAggregates(query, "in wrong place");
result.collectUsedColumns(query);
result.collectUsedColumns(query, false);
return std::make_shared<const SyntaxAnalyzerResult>(result);
}

View File

@ -46,11 +46,11 @@ struct SyntaxAnalyzerResult
/// Predicate optimizer overrides the sub queries
bool rewrite_subqueries = false;
bool optimize_trivial_count = false;
/// Results of scalar sub queries
Scalars scalars;
bool maybe_optimize_trivial_count = false;
SyntaxAnalyzerResult(const NamesAndTypesList & source_columns_, ConstStoragePtr storage_ = {}, bool add_special = true)
: storage(storage_)
, source_columns(source_columns_)
@ -59,7 +59,7 @@ struct SyntaxAnalyzerResult
}
void collectSourceColumns(bool add_special);
void collectUsedColumns(const ASTPtr & query);
void collectUsedColumns(const ASTPtr & query, bool is_select);
Names requiredSourceColumns() const { return required_source_columns.getNames(); }
const Scalars & getScalars() const { return scalars; }
};

View File

@ -0,0 +1,5 @@
4
4
4
4
4

View File

@ -0,0 +1,10 @@
-- Regression test: count() over joined tables must reflect the joined row count.
-- Every SELECT below is expected to print 4 (see the accompanying reference output).
drop table if exists t;
-- Left-hand table: a Memory table filled from numbers(2).
create table t engine Memory as select * from numbers(2);
-- Comma-separated table list, then the explicit CROSS JOIN spelling.
select count(*) from t, numbers(2) r;
select count(*) from t cross join numbers(2) r;
-- count() with no arguments.
select count() from t cross join numbers(2) r;
-- count() over a column taken from the left side, then from the right side.
select count(t.number) from t cross join numbers(2) r;
select count(r.number) from t cross join numbers(2) r;
-- Clean up so the test leaves no state behind.
drop table t;