mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-15 12:14:18 +00:00
revert removing fuse_sum_count ast optimization
This commit is contained in:
parent
7daf5200f0
commit
14e8daf078
@ -193,7 +193,7 @@ void addQueryTreePasses(QueryTreePassManager & manager)
|
||||
manager.addPass(std::make_unique<OrderByTupleEliminationPass>());
|
||||
manager.addPass(std::make_unique<OrderByLimitByDuplicateEliminationPass>());
|
||||
|
||||
if (settings.optimize_syntax_fuse_functions)
|
||||
if (settings.optimize_fuse_sum_count_avg)
|
||||
manager.addPass(std::make_unique<FuseFunctionsPass>());
|
||||
}
|
||||
|
||||
|
@ -522,7 +522,8 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
|
||||
M(Bool, allow_non_metadata_alters, true, "Allow to execute alters which affects not only tables metadata, but also data on disk", 0) \
|
||||
M(Bool, enable_global_with_statement, true, "Propagate WITH statements to UNION queries and all subqueries", 0) \
|
||||
M(Bool, aggregate_functions_null_for_empty, false, "Rewrite all aggregate functions in a query, adding -OrNull suffix to them", 0) \
|
||||
M(Bool, optimize_syntax_fuse_functions, false, " Allow apply query tree optimisation: fuse aggregate functions (e.g. replace functions `sum, avg, count` with identical arguments into one `sumCount`)", 0) \
|
||||
M(Bool, optimize_syntax_fuse_functions, false, "Not ready for production, do not use. Allow apply syntax optimisation: fuse aggregate functions", 0) \
|
||||
M(Bool, optimize_fuse_sum_count_avg, false, "Replace calls of functions `sum`, `avg`, `count` with identical arguments into one `sumCount`", 0) \
|
||||
M(Bool, flatten_nested, true, "If true, columns of type Nested will be flatten to separate array columns instead of one array of tuples", 0) \
|
||||
M(Bool, asterisk_include_materialized_columns, false, "Include MATERIALIZED columns for wildcard query", 0) \
|
||||
M(Bool, asterisk_include_alias_columns, false, "Include ALIAS columns for wildcard query", 0) \
|
||||
@ -692,9 +693,8 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
|
||||
MAKE_OBSOLETE(M, UInt64, background_message_broker_schedule_pool_size, 16) \
|
||||
MAKE_OBSOLETE(M, UInt64, background_distributed_schedule_pool_size, 16) \
|
||||
MAKE_OBSOLETE(M, DefaultDatabaseEngine, default_database_engine, DefaultDatabaseEngine::Atomic) \
|
||||
MAKE_OBSOLETE(M, UInt64, max_pipeline_depth, 0) \
|
||||
MAKE_OBSOLETE(M, UInt64, max_pipeline_depth, 0) \
|
||||
MAKE_OBSOLETE(M, Seconds, temporary_live_view_timeout, 1) \
|
||||
MAKE_OBSOLETE(M, Bool, optimize_fuse_sum_count_avg, false) \
|
||||
|
||||
/** The section above is for obsolete settings. Do not add anything there. */
|
||||
|
||||
|
@ -203,8 +203,73 @@ struct CustomizeAggregateFunctionsMoveSuffixData
|
||||
}
|
||||
};
|
||||
|
||||
struct FuseSumCountAggregates
|
||||
{
|
||||
std::vector<ASTFunction *> sums {};
|
||||
std::vector<ASTFunction *> counts {};
|
||||
std::vector<ASTFunction *> avgs {};
|
||||
|
||||
void addFuncNode(ASTFunction * func)
|
||||
{
|
||||
if (func->name == "sum")
|
||||
sums.push_back(func);
|
||||
else if (func->name == "count")
|
||||
counts.push_back(func);
|
||||
else
|
||||
{
|
||||
assert(func->name == "avg");
|
||||
avgs.push_back(func);
|
||||
}
|
||||
}
|
||||
|
||||
bool canBeFused() const
|
||||
{
|
||||
// Need at least two different kinds of functions to fuse.
|
||||
if (sums.empty() && counts.empty())
|
||||
return false;
|
||||
if (sums.empty() && avgs.empty())
|
||||
return false;
|
||||
if (counts.empty() && avgs.empty())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
struct FuseSumCountAggregatesVisitorData
|
||||
{
|
||||
using TypeToVisit = ASTFunction;
|
||||
|
||||
std::unordered_map<String, FuseSumCountAggregates> fuse_map;
|
||||
|
||||
void visit(ASTFunction & func, ASTPtr &)
|
||||
{
|
||||
if (func.name == "sum" || func.name == "avg" || func.name == "count")
|
||||
{
|
||||
if (func.arguments->children.empty())
|
||||
return;
|
||||
|
||||
// Probably we can extend it to match count() for non-nullable argument
|
||||
// to sum/avg with any other argument. Now we require strict match.
|
||||
const auto argument = func.arguments->children.at(0)->getColumnName();
|
||||
auto it = fuse_map.find(argument);
|
||||
if (it != fuse_map.end())
|
||||
{
|
||||
it->second.addFuncNode(&func);
|
||||
}
|
||||
else
|
||||
{
|
||||
FuseSumCountAggregates funcs{};
|
||||
funcs.addFuncNode(&func);
|
||||
fuse_map[argument] = funcs;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using CustomizeAggregateFunctionsOrNullVisitor = InDepthNodeVisitor<OneTypeMatcher<CustomizeAggregateFunctionsSuffixData>, true>;
|
||||
using CustomizeAggregateFunctionsMoveOrNullVisitor = InDepthNodeVisitor<OneTypeMatcher<CustomizeAggregateFunctionsMoveSuffixData>, true>;
|
||||
using FuseSumCountAggregatesVisitor = InDepthNodeVisitor<OneTypeMatcher<FuseSumCountAggregatesVisitorData>, true>;
|
||||
|
||||
|
||||
struct ExistsExpressionData
|
||||
{
|
||||
@ -276,6 +341,52 @@ void translateQualifiedNames(ASTPtr & query, const ASTSelectQuery & select_query
|
||||
throw Exception("Empty list of columns in SELECT query", ErrorCodes::EMPTY_LIST_OF_COLUMNS_QUERIED);
|
||||
}
|
||||
|
||||
// Replaces one avg/sum/count function with an appropriate expression with
|
||||
// sumCount().
|
||||
void replaceWithSumCount(String column_name, ASTFunction & func)
|
||||
{
|
||||
auto func_base = makeASTFunction("sumCount", std::make_shared<ASTIdentifier>(column_name));
|
||||
auto exp_list = std::make_shared<ASTExpressionList>();
|
||||
if (func.name == "sum" || func.name == "count")
|
||||
{
|
||||
/// Rewrite "sum" to sumCount().1, rewrite "count" to sumCount().2
|
||||
UInt8 idx = (func.name == "sum" ? 1 : 2);
|
||||
func.name = "tupleElement";
|
||||
exp_list->children.push_back(func_base);
|
||||
exp_list->children.push_back(std::make_shared<ASTLiteral>(idx));
|
||||
}
|
||||
else
|
||||
{
|
||||
/// Rewrite "avg" to sumCount().1 / sumCount().2
|
||||
auto new_arg1 = makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(UInt8(1)));
|
||||
auto new_arg2 = makeASTFunction("CAST",
|
||||
makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(static_cast<UInt8>(2))),
|
||||
std::make_shared<ASTLiteral>("Float64"));
|
||||
|
||||
func.name = "divide";
|
||||
exp_list->children.push_back(new_arg1);
|
||||
exp_list->children.push_back(new_arg2);
|
||||
}
|
||||
func.arguments = exp_list;
|
||||
func.children.push_back(func.arguments);
|
||||
}
|
||||
|
||||
void fuseSumCountAggregates(std::unordered_map<String, FuseSumCountAggregates> & fuse_map)
|
||||
{
|
||||
for (auto & it : fuse_map)
|
||||
{
|
||||
if (it.second.canBeFused())
|
||||
{
|
||||
for (auto & func: it.second.sums)
|
||||
replaceWithSumCount(it.first, *func);
|
||||
for (auto & func: it.second.avgs)
|
||||
replaceWithSumCount(it.first, *func);
|
||||
for (auto & func: it.second.counts)
|
||||
replaceWithSumCount(it.first, *func);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool hasArrayJoin(const ASTPtr & ast)
|
||||
{
|
||||
if (const ASTFunction * function = ast->as<ASTFunction>())
|
||||
@ -1329,6 +1440,17 @@ void TreeRewriter::normalize(
|
||||
CustomizeGlobalNotInVisitor(data_global_not_null_in).visit(query);
|
||||
}
|
||||
|
||||
// Try to fuse sum/avg/count with identical arguments to one sumCount call,
|
||||
// if we have at least two different functions. E.g. we will replace sum(x)
|
||||
// and count(x) with sumCount(x).1 and sumCount(x).2, and sumCount() will
|
||||
// be calculated only once because of CSE.
|
||||
if (settings.optimize_fuse_sum_count_avg && settings.optimize_syntax_fuse_functions)
|
||||
{
|
||||
FuseSumCountAggregatesVisitor::Data data;
|
||||
FuseSumCountAggregatesVisitor(data).visit(query);
|
||||
fuseSumCountAggregates(data.fuse_map);
|
||||
}
|
||||
|
||||
/// Rewrite all aggregate functions to add -OrNull suffix to them
|
||||
if (settings.aggregate_functions_null_for_empty)
|
||||
{
|
||||
|
@ -1,93 +1,12 @@
|
||||
5 6 3 2 2 7 2
|
||||
5 6 3 2 2 7 2
|
||||
QUERY id: 0
|
||||
PROJECTION COLUMNS
|
||||
sum(plus(a, 1)) Nullable(Int64)
|
||||
sum(b) Int64
|
||||
count(b) UInt64
|
||||
avg(b) Float64
|
||||
count(plus(a, 1)) UInt64
|
||||
sum(plus(a, 2)) Nullable(Int64)
|
||||
count(a) UInt64
|
||||
PROJECTION
|
||||
LIST id: 1, nodes: 7
|
||||
FUNCTION id: 2, function_name: tupleElement, function_type: ordinary, result_type: Nullable(Int64)
|
||||
ARGUMENTS
|
||||
LIST id: 3, nodes: 2
|
||||
FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
|
||||
ARGUMENTS
|
||||
LIST id: 5, nodes: 1
|
||||
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: Nullable(Int16)
|
||||
ARGUMENTS
|
||||
LIST id: 7, nodes: 2
|
||||
COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9
|
||||
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
|
||||
CONSTANT id: 11, constant_value: UInt64_1, constant_value_type: UInt8
|
||||
FUNCTION id: 12, function_name: tupleElement, function_type: ordinary, result_type: Int64
|
||||
ARGUMENTS
|
||||
LIST id: 13, nodes: 2
|
||||
FUNCTION id: 14, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
|
||||
ARGUMENTS
|
||||
LIST id: 15, nodes: 1
|
||||
COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9
|
||||
CONSTANT id: 17, constant_value: UInt64_1, constant_value_type: UInt8
|
||||
FUNCTION id: 18, function_name: tupleElement, function_type: ordinary, result_type: UInt64
|
||||
ARGUMENTS
|
||||
LIST id: 19, nodes: 2
|
||||
FUNCTION id: 20, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
|
||||
ARGUMENTS
|
||||
LIST id: 21, nodes: 1
|
||||
COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9
|
||||
CONSTANT id: 22, constant_value: UInt64_2, constant_value_type: UInt8
|
||||
FUNCTION id: 23, function_name: divide, function_type: ordinary, result_type: Float64
|
||||
ARGUMENTS
|
||||
LIST id: 24, nodes: 2
|
||||
FUNCTION id: 25, function_name: tupleElement, function_type: ordinary, result_type: Int64
|
||||
ARGUMENTS
|
||||
LIST id: 26, nodes: 2
|
||||
FUNCTION id: 27, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
|
||||
ARGUMENTS
|
||||
LIST id: 28, nodes: 1
|
||||
COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9
|
||||
CONSTANT id: 29, constant_value: UInt64_1, constant_value_type: UInt8
|
||||
FUNCTION id: 30, function_name: toFloat64, function_type: ordinary, result_type: Float64
|
||||
ARGUMENTS
|
||||
LIST id: 31, nodes: 1
|
||||
FUNCTION id: 32, function_name: tupleElement, function_type: ordinary, result_type: UInt64
|
||||
ARGUMENTS
|
||||
LIST id: 33, nodes: 2
|
||||
FUNCTION id: 27, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
|
||||
ARGUMENTS
|
||||
LIST id: 28, nodes: 1
|
||||
COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9
|
||||
CONSTANT id: 34, constant_value: UInt64_2, constant_value_type: UInt8
|
||||
FUNCTION id: 35, function_name: tupleElement, function_type: ordinary, result_type: UInt64
|
||||
ARGUMENTS
|
||||
LIST id: 36, nodes: 2
|
||||
FUNCTION id: 37, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
|
||||
ARGUMENTS
|
||||
LIST id: 38, nodes: 1
|
||||
FUNCTION id: 39, function_name: plus, function_type: ordinary, result_type: Nullable(Int16)
|
||||
ARGUMENTS
|
||||
LIST id: 40, nodes: 2
|
||||
COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9
|
||||
CONSTANT id: 41, constant_value: UInt64_1, constant_value_type: UInt8
|
||||
CONSTANT id: 42, constant_value: UInt64_2, constant_value_type: UInt8
|
||||
FUNCTION id: 43, function_name: sum, function_type: aggregate, result_type: Nullable(Int64)
|
||||
ARGUMENTS
|
||||
LIST id: 44, nodes: 1
|
||||
FUNCTION id: 45, function_name: plus, function_type: ordinary, result_type: Nullable(Int16)
|
||||
ARGUMENTS
|
||||
LIST id: 46, nodes: 2
|
||||
COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9
|
||||
CONSTANT id: 47, constant_value: UInt64_2, constant_value_type: UInt8
|
||||
FUNCTION id: 48, function_name: count, function_type: aggregate, result_type: UInt64
|
||||
ARGUMENTS
|
||||
LIST id: 49, nodes: 1
|
||||
COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9
|
||||
JOIN TREE
|
||||
TABLE id: 9, table_name: default.fuse_tbl
|
||||
0 0 nan
|
||||
0 0 nan
|
||||
45 10 4.5 Decimal(38, 0) UInt64 Float64
|
||||
45 10 4.5 Decimal(38, 0) UInt64 Float64
|
||||
210 230 20
|
||||
SELECT
|
||||
sum(a),
|
||||
sumCount(b).1,
|
||||
sumCount(b).2
|
||||
FROM fuse_tbl
|
||||
---------NOT trigger fuse--------
|
||||
210 11.5
|
||||
SELECT
|
||||
sum(a),
|
||||
avg(b)
|
||||
FROM fuse_tbl
|
||||
|
@ -1,10 +1,9 @@
|
||||
SET allow_experimental_analyzer = 1;
|
||||
|
||||
DROP TABLE IF EXISTS fuse_tbl;
|
||||
CREATE TABLE fuse_tbl(a Nullable(Int8), b Int8) Engine = Log;
|
||||
INSERT INTO fuse_tbl VALUES (1, 1), (2, 2), (NULL, 3);
|
||||
CREATE TABLE fuse_tbl(a Int8, b Int8) Engine = Log;
|
||||
INSERT INTO fuse_tbl SELECT number, number + 1 FROM numbers(1, 20);
|
||||
|
||||
SET optimize_syntax_fuse_functions = 1;
|
||||
SET optimize_fuse_sum_count_avg = 1;
|
||||
|
||||
SELECT sum(a), sum(b), count(b) from fuse_tbl;
|
||||
EXPLAIN SYNTAX SELECT sum(a), sum(b), count(b) from fuse_tbl;
|
||||
@ -12,14 +11,4 @@ SELECT '---------NOT trigger fuse--------';
|
||||
SELECT sum(a), avg(b) from fuse_tbl;
|
||||
EXPLAIN SYNTAX SELECT sum(a), avg(b) from fuse_tbl;
|
||||
|
||||
SELECT sum(a + 1), sum(b), count(b), avg(b), count(a + 1), sum(a + 2), count(a) from fuse_tbl SETTINGS optimize_syntax_fuse_functions = 0;
|
||||
SELECT sum(a + 1), sum(b), count(b), avg(b), count(a + 1), sum(a + 2), count(a) from fuse_tbl;
|
||||
EXPLAIN QUERY TREE run_passes = 1 SELECT sum(a + 1), sum(b), count(b), avg(b), count(a + 1), sum(a + 2), count(a) from fuse_tbl;
|
||||
|
||||
SELECT sum(x), count(x), avg(x) FROM (SELECT number :: Decimal32(0) AS x FROM numbers(0)) SETTINGS optimize_syntax_fuse_functions = 0;
|
||||
SELECT sum(x), count(x), avg(x) FROM (SELECT number :: Decimal32(0) AS x FROM numbers(0));
|
||||
|
||||
SELECT sum(x), count(x), avg(x), toTypeName(sum(x)), toTypeName(count(x)), toTypeName(avg(x)) FROM (SELECT number :: Decimal32(0) AS x FROM numbers(10)) SETTINGS optimize_syntax_fuse_functions = 0;
|
||||
SELECT sum(x), count(x), avg(x), toTypeName(sum(x)), toTypeName(count(x)), toTypeName(avg(x)) FROM (SELECT number :: Decimal32(0) AS x FROM numbers(10));
|
||||
|
||||
DROP TABLE fuse_tbl;
|
||||
|
Loading…
Reference in New Issue
Block a user