fixes for FuseFunctionsPass

This commit is contained in:
vdimir 2022-11-07 14:24:27 +00:00
parent 5de29257e6
commit f732c97cf6
No known key found for this signature in database
GPG Key ID: 6EE4CE2BEDC51862
4 changed files with 231 additions and 93 deletions

View File

@ -38,9 +38,13 @@ public:
if (!function_node || !function_node->isAggregateFunction() || !matchFunctionName(function_node->getFunctionName()))
return;
if (function_node->getResultType()->isNullable())
/// Do not apply to functions with Nullable result type, because `sumCount` handles it differently from `sum` and `avg`.
return;
const auto & arguments = function_node->getArgumentsNode()->getChildren();
if (arguments.size() != 1)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Aggregate function {} must have exactly one argument", function_node->getFunctionName());
throw Exception(ErrorCodes::LOGICAL_ERROR, "Aggregate function {} should have exactly one argument", function_node->getFunctionName());
mapping[QueryTreeNodeWithHash(arguments[0])].push_back(&node);
}
@ -111,12 +115,12 @@ void replaceWithSumCount(QueryTreeNodePtr & node, const FunctionNodePtr & sum_co
if (function_name == "sum")
{
assert(node->getResultType() == sum_count_result_type->getElement(0));
assert(node->getResultType()->equals(*sum_count_result_type->getElement(0)));
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 1);
}
else if (function_name == "count")
{
assert(node->getResultType() == sum_count_result_type->getElement(1));
assert(node->getResultType()->equals(*sum_count_result_type->getElement(1)));
node = createTupleElementFunction(context, node->getResultType(), sum_count_node, 2);
}
else if (function_name == "avg")
@ -148,6 +152,7 @@ void FuseFunctionsPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context
auto sum_count_node = createSumCoundNode(argument.node);
for (auto * node : nodes)
{
assert(node);
replaceWithSumCount(*node, sum_count_node, context);
}
}

View File

@ -67,7 +67,6 @@ public:
*
* TODO: Support _shard_num into shardNum() rewriting.
* TODO: Support logical expressions optimizer.
* TODO: Support fuse sum count optimize_fuse_sum_count_avg, optimize_syntax_fuse_functions.
* TODO: Support setting convert_query_to_cnf.
* TODO: Support setting optimize_using_constraints.
* TODO: Support setting optimize_substitute_columns.
@ -79,7 +78,7 @@ public:
* TODO: Support setting optimize_redundant_functions_in_order_by.
* TODO: Support setting optimize_monotonous_functions_in_order_by.
* TODO: Support setting optimize_if_transform_strings_to_enum.
* TODO: Support settings.optimize_syntax_fuse_functions.
* TODO: Support fuse quantile functions optimize_syntax_fuse_functions.
* TODO: Support settings.optimize_or_like_chain.
* TODO: Add optimizations based on function semantics. Example: SELECT * FROM test_table WHERE id != id. (id is not nullable column).
*/

View File

@ -1,5 +1,64 @@
1.5 3
\N \N
1.5 3
2 6
6 10 9
5 6 3 2 2 7 2
5 6 3 2 2 7 2
QUERY id: 0
PROJECTION COLUMNS
sum(a) Nullable(Int64)
avg(a) Nullable(Float64)
PROJECTION
LIST id: 1, nodes: 2
FUNCTION id: 2, function_name: sum, function_type: aggregate, result_type: Nullable(Int64)
ARGUMENTS
LIST id: 3, nodes: 1
COLUMN id: 4, column_name: a, result_type: Nullable(Int8), source_id: 5
FUNCTION id: 6, function_name: avg, function_type: aggregate, result_type: Nullable(Float64)
ARGUMENTS
LIST id: 7, nodes: 1
COLUMN id: 4, column_name: a, result_type: Nullable(Int8), source_id: 5
JOIN TREE
TABLE id: 5, table_name: default.fuse_tbl
QUERY id: 0
PROJECTION COLUMNS
sum(b) Int64
avg(b) Float64
PROJECTION
LIST id: 1, nodes: 2
FUNCTION id: 2, function_name: tupleElement, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 3, nodes: 2
FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 5, nodes: 1
COLUMN id: 6, column_name: b, result_type: Int8, source_id: 7
CONSTANT id: 8, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 9, function_name: divide, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 10, nodes: 2
FUNCTION id: 11, function_name: tupleElement, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 12, nodes: 2
FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 5, nodes: 1
COLUMN id: 6, column_name: b, result_type: Int8, source_id: 7
CONSTANT id: 13, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 14, function_name: toFloat64, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 15, nodes: 1
FUNCTION id: 16, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 17, nodes: 2
FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 5, nodes: 1
COLUMN id: 6, column_name: b, result_type: Int8, source_id: 7
CONSTANT id: 18, constant_value: UInt64_2, constant_value_type: UInt8
JOIN TREE
TABLE id: 7, table_name: default.fuse_tbl
QUERY id: 0
PROJECTION COLUMNS
sum(plus(a, 1)) Nullable(Int64)
@ -11,98 +70,158 @@ QUERY id: 0
count(a) UInt64
PROJECTION
LIST id: 1, nodes: 7
FUNCTION id: 2, function_name: tupleElement, function_type: ordinary, result_type: Nullable(Int64)
ARGUMENTS
LIST id: 3, nodes: 2
FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: Nullable(Int16)
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
CONSTANT id: 11, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 12, function_name: tupleElement, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 13, nodes: 2
FUNCTION id: 14, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 15, nodes: 1
COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9
CONSTANT id: 17, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 18, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 19, nodes: 2
FUNCTION id: 14, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 15, nodes: 1
COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9
CONSTANT id: 20, constant_value: UInt64_2, constant_value_type: UInt8
FUNCTION id: 21, function_name: divide, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 22, nodes: 2
FUNCTION id: 23, function_name: tupleElement, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 24, nodes: 2
FUNCTION id: 14, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 15, nodes: 1
COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9
CONSTANT id: 25, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 26, function_name: toFloat64, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 27, nodes: 1
FUNCTION id: 28, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 29, nodes: 2
FUNCTION id: 14, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 15, nodes: 1
COLUMN id: 16, column_name: b, result_type: Int8, source_id: 9
CONSTANT id: 30, constant_value: UInt64_2, constant_value_type: UInt8
FUNCTION id: 31, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 32, nodes: 2
FUNCTION id: 4, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 5, nodes: 1
FUNCTION id: 6, function_name: plus, function_type: ordinary, result_type: Nullable(Int16)
ARGUMENTS
LIST id: 7, nodes: 2
COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9
CONSTANT id: 10, constant_value: UInt64_1, constant_value_type: UInt8
CONSTANT id: 33, constant_value: UInt64_2, constant_value_type: UInt8
FUNCTION id: 34, function_name: sum, function_type: aggregate, result_type: Nullable(Int64)
ARGUMENTS
LIST id: 35, nodes: 1
FUNCTION id: 36, function_name: plus, function_type: ordinary, result_type: Nullable(Int16)
ARGUMENTS
LIST id: 37, nodes: 2
COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9
CONSTANT id: 38, constant_value: UInt64_2, constant_value_type: UInt8
FUNCTION id: 39, function_name: count, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 40, nodes: 1
COLUMN id: 8, column_name: a, result_type: Nullable(Int8), source_id: 9
JOIN TREE
TABLE id: 9, table_name: default.fuse_tbl
QUERY id: 0
PROJECTION COLUMNS
sum(a) Nullable(Int64)
avg(b) Float64
PROJECTION
LIST id: 1, nodes: 2
FUNCTION id: 2, function_name: sum, function_type: aggregate, result_type: Nullable(Int64)
ARGUMENTS
LIST id: 3, nodes: 1
COLUMN id: 4, column_name: a, result_type: Nullable(Int8), source_id: 5
FUNCTION id: 6, function_name: avg, function_type: aggregate, result_type: Float64
FUNCTION id: 4, function_name: plus, function_type: ordinary, result_type: Nullable(Int16)
ARGUMENTS
LIST id: 5, nodes: 2
COLUMN id: 6, column_name: a, result_type: Nullable(Int8), source_id: 7
CONSTANT id: 8, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 9, function_name: tupleElement, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 7, nodes: 1
COLUMN id: 8, column_name: b, result_type: Int8, source_id: 5
LIST id: 10, nodes: 2
FUNCTION id: 11, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 12, nodes: 1
COLUMN id: 13, column_name: b, result_type: Int8, source_id: 7
CONSTANT id: 14, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 15, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 16, nodes: 2
FUNCTION id: 11, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 12, nodes: 1
COLUMN id: 13, column_name: b, result_type: Int8, source_id: 7
CONSTANT id: 17, constant_value: UInt64_2, constant_value_type: UInt8
FUNCTION id: 18, function_name: divide, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 19, nodes: 2
FUNCTION id: 20, function_name: tupleElement, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 21, nodes: 2
FUNCTION id: 11, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 12, nodes: 1
COLUMN id: 13, column_name: b, result_type: Int8, source_id: 7
CONSTANT id: 22, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 23, function_name: toFloat64, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 24, nodes: 1
FUNCTION id: 25, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 26, nodes: 2
FUNCTION id: 11, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 12, nodes: 1
COLUMN id: 13, column_name: b, result_type: Int8, source_id: 7
CONSTANT id: 27, constant_value: UInt64_2, constant_value_type: UInt8
FUNCTION id: 28, function_name: count, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 29, nodes: 1
FUNCTION id: 30, function_name: plus, function_type: ordinary, result_type: Nullable(Int16)
ARGUMENTS
LIST id: 31, nodes: 2
COLUMN id: 6, column_name: a, result_type: Nullable(Int8), source_id: 7
CONSTANT id: 32, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 33, function_name: sum, function_type: aggregate, result_type: Nullable(Int64)
ARGUMENTS
LIST id: 34, nodes: 1
FUNCTION id: 35, function_name: plus, function_type: ordinary, result_type: Nullable(Int16)
ARGUMENTS
LIST id: 36, nodes: 2
COLUMN id: 6, column_name: a, result_type: Nullable(Int8), source_id: 7
CONSTANT id: 37, constant_value: UInt64_2, constant_value_type: UInt8
FUNCTION id: 38, function_name: count, function_type: aggregate, result_type: UInt64
ARGUMENTS
LIST id: 39, nodes: 1
COLUMN id: 6, column_name: a, result_type: Nullable(Int8), source_id: 7
JOIN TREE
TABLE id: 5, table_name: default.fuse_tbl
TABLE id: 7, table_name: default.fuse_tbl
QUERY id: 0
PROJECTION COLUMNS
multiply(avg(b), 3) Float64
plus(plus(sum(b), 1), count(b)) Int64
multiply(count(b), count(b)) UInt64
PROJECTION
LIST id: 1, nodes: 3
FUNCTION id: 2, function_name: multiply, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 3, nodes: 2
FUNCTION id: 4, function_name: divide, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 5, nodes: 2
FUNCTION id: 6, function_name: tupleElement, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 7, nodes: 2
FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 9, nodes: 1
COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11
CONSTANT id: 12, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 13, function_name: toFloat64, function_type: ordinary, result_type: Float64
ARGUMENTS
LIST id: 14, nodes: 1
FUNCTION id: 15, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 16, nodes: 2
FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 9, nodes: 1
COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11
CONSTANT id: 17, constant_value: UInt64_2, constant_value_type: UInt8
CONSTANT id: 18, constant_value: UInt64_3, constant_value_type: UInt8
FUNCTION id: 19, function_name: plus, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 20, nodes: 2
FUNCTION id: 21, function_name: plus, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 22, nodes: 2
FUNCTION id: 23, function_name: tupleElement, function_type: ordinary, result_type: Int64
ARGUMENTS
LIST id: 24, nodes: 2
FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 9, nodes: 1
COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11
CONSTANT id: 25, constant_value: UInt64_1, constant_value_type: UInt8
CONSTANT id: 26, constant_value: UInt64_1, constant_value_type: UInt8
FUNCTION id: 27, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 28, nodes: 2
FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 9, nodes: 1
COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11
CONSTANT id: 29, constant_value: UInt64_2, constant_value_type: UInt8
FUNCTION id: 30, function_name: multiply, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 31, nodes: 2
FUNCTION id: 32, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 33, nodes: 2
FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 9, nodes: 1
COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11
CONSTANT id: 34, constant_value: UInt64_2, constant_value_type: UInt8
FUNCTION id: 35, function_name: tupleElement, function_type: ordinary, result_type: UInt64
ARGUMENTS
LIST id: 36, nodes: 2
FUNCTION id: 8, function_name: sumCount, function_type: aggregate, result_type: Tuple(Int64, UInt64)
ARGUMENTS
LIST id: 9, nodes: 1
COLUMN id: 10, column_name: b, result_type: Int8, source_id: 11
CONSTANT id: 37, constant_value: UInt64_2, constant_value_type: UInt8
JOIN TREE
QUERY id: 11, is_subquery: 1
PROJECTION COLUMNS
b Int8
PROJECTION
LIST id: 38, nodes: 1
COLUMN id: 39, column_name: b, result_type: Int8, source_id: 40
JOIN TREE
TABLE id: 40, table_name: default.fuse_tbl
0 0 nan
0 0 nan
45 10 4.5 Decimal(38, 0) UInt64 Float64

View File

@ -7,11 +7,26 @@ CREATE TABLE fuse_tbl(a Nullable(Int8), b Int8) Engine = Log;
INSERT INTO fuse_tbl VALUES (1, 1), (2, 2), (NULL, 3);
SELECT avg(a), sum(a) FROM (SELECT a FROM fuse_tbl);
SELECT avg(a), sum(a) FROM (SELECT a FROM fuse_tbl WHERE isNull(a));
SELECT avg(a), sum(a) FROM (SELECT a FROM fuse_tbl WHERE isNotNull(a));
SELECT avg(b), sum(b) FROM (SELECT b FROM fuse_tbl);
SELECT avg(b) * 3, sum(b) + 1 + count(b), count(b) * count(b) FROM (SELECT b FROM fuse_tbl);
-- TODO(@vdimir): uncomment after https://github.com/ClickHouse/ClickHouse/pull/42865
-- SELECT sum(b), count(b) from (SELECT x as b FROM (SELECT sum(b) as x, count(b) FROM fuse_tbl) );
SELECT sum(a + 1), sum(b), count(b), avg(b), count(a + 1), sum(a + 2), count(a) from fuse_tbl SETTINGS optimize_syntax_fuse_functions = 0;
SELECT sum(a + 1), sum(b), count(b), avg(b), count(a + 1), sum(a + 2), count(a) from fuse_tbl;
EXPLAIN QUERY TREE run_passes = 1 SELECT sum(a), avg(a) from fuse_tbl;
EXPLAIN QUERY TREE run_passes = 1 SELECT sum(b), avg(b) from fuse_tbl;
EXPLAIN QUERY TREE run_passes = 1 SELECT sum(a + 1), sum(b), count(b), avg(b), count(a + 1), sum(a + 2), count(a) from fuse_tbl;
EXPLAIN QUERY TREE run_passes = 1 SELECT sum(a), avg(b) from fuse_tbl;
EXPLAIN QUERY TREE run_passes = 1 SELECT avg(b) * 3, sum(b) + 1 + count(b), count(b) * count(b) FROM (SELECT b FROM fuse_tbl);
-- TODO(@vdimir): uncomment after https://github.com/ClickHouse/ClickHouse/pull/42865
-- EXPLAIN QUERY TREE run_passes = 1 SELECT sum(b), count(b) from (SELECT x as b FROM (SELECT sum(b) as x, count(b) FROM fuse_tbl) );
SELECT sum(x), count(x), avg(x) FROM (SELECT number :: Decimal32(0) AS x FROM numbers(0)) SETTINGS optimize_syntax_fuse_functions = 0;
SELECT sum(x), count(x), avg(x) FROM (SELECT number :: Decimal32(0) AS x FROM numbers(0));